From de1a2317204d91ab387ae31c1cdb078a81a81af9 Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Wed, 20 May 2026 22:44:51 +0000 Subject: [PATCH 1/6] feat(mcp): add list and get tools for annotation layers and annotations Co-Authored-By: kasiazjc --- .../mcp_service/annotation_layer/__init__.py | 16 + .../mcp_service/annotation_layer/schemas.py | 306 ++++++++++++ .../annotation_layer/tool/__init__.py | 28 ++ .../tool/get_annotation_layer_info.py | 93 ++++ .../tool/get_layer_annotation_info.py | 130 ++++++ .../tool/list_annotation_layers.py | 123 +++++ .../tool/list_layer_annotations.py | 153 ++++++ superset/mcp_service/app.py | 12 + .../mcp_service/annotation_layer/__init__.py | 16 + .../annotation_layer/tool/__init__.py | 16 + .../tool/test_annotation_layer_tools.py | 434 ++++++++++++++++++ 11 files changed, 1327 insertions(+) create mode 100644 superset/mcp_service/annotation_layer/__init__.py create mode 100644 superset/mcp_service/annotation_layer/schemas.py create mode 100644 superset/mcp_service/annotation_layer/tool/__init__.py create mode 100644 superset/mcp_service/annotation_layer/tool/get_annotation_layer_info.py create mode 100644 superset/mcp_service/annotation_layer/tool/get_layer_annotation_info.py create mode 100644 superset/mcp_service/annotation_layer/tool/list_annotation_layers.py create mode 100644 superset/mcp_service/annotation_layer/tool/list_layer_annotations.py create mode 100644 tests/unit_tests/mcp_service/annotation_layer/__init__.py create mode 100644 tests/unit_tests/mcp_service/annotation_layer/tool/__init__.py create mode 100644 tests/unit_tests/mcp_service/annotation_layer/tool/test_annotation_layer_tools.py diff --git a/superset/mcp_service/annotation_layer/__init__.py b/superset/mcp_service/annotation_layer/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/superset/mcp_service/annotation_layer/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/mcp_service/annotation_layer/schemas.py b/superset/mcp_service/annotation_layer/schemas.py new file mode 100644 index 000000000000..11059ab75b43 --- /dev/null +++ b/superset/mcp_service/annotation_layer/schemas.py @@ -0,0 +1,306 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Pydantic schemas for annotation layer and annotation responses.""" + +from __future__ import annotations + +from datetime import datetime +from typing import Annotated, Any, List, Literal + +from pydantic import ( + BaseModel, + ConfigDict, + Field, + field_validator, + model_validator, + PositiveInt, +) + +from superset.daos.base import ColumnOperator, ColumnOperatorEnum +from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE +from superset.mcp_service.system.schemas import PaginationInfo +from superset.mcp_service.utils.schema_utils import ( + parse_json_or_list, + parse_json_or_model_list, +) + +DEFAULT_LAYER_COLUMNS = ["id", "name", "descr"] +DEFAULT_ANNOTATION_COLUMNS = ["id", "short_descr", "start_dttm", "end_dttm"] + + +class AnnotationLayerFilter(ColumnOperator): + """Filter object for annotation layer listing.""" + + col: Literal["name"] = Field( + ..., + description="Column to filter on. Supported: 'name'.", + ) + opr: ColumnOperatorEnum = Field(..., description="Filter operator.") + value: str | int | float | bool | List[str | int | float | bool] = Field( + ..., description="Value to filter by." + ) + + +class AnnotationFilter(ColumnOperator): + """Filter object for annotation listing.""" + + col: Literal["short_descr"] = Field( + ..., + description="Column to filter on. Supported: 'short_descr'.", + ) + opr: ColumnOperatorEnum = Field(..., description="Filter operator.") + value: str | int | float | bool | List[str | int | float | bool] = Field( + ..., description="Value to filter by." + ) + + +class AnnotationLayerInfo(BaseModel): + id: int | None = Field(None, description="Annotation layer ID") + name: str | None = Field(None, description="Annotation layer name") + descr: str | None = Field(None, description="Annotation layer description") + changed_on: str | datetime | None = Field( + None, description="Last modification timestamp" + ) + created_on: str | datetime | None = Field(None, description="Creation timestamp") + model_config = ConfigDict(from_attributes=True, ser_json_timedelta="iso8601") + + +class AnnotationLayerList(BaseModel): + annotation_layers: List[AnnotationLayerInfo] + count: int + total_count: int + page: int + page_size: int + total_pages: int + has_previous: bool + has_next: bool + columns_requested: List[str] = Field(default_factory=list) + columns_loaded: List[str] = Field(default_factory=list) + columns_available: List[str] = Field(default_factory=list) + sortable_columns: List[str] = Field(default_factory=list) + filters_applied: List[AnnotationLayerFilter] = Field(default_factory=list) + pagination: PaginationInfo | None = None + timestamp: datetime | None = None + model_config = ConfigDict(ser_json_timedelta="iso8601") + + +class ListAnnotationLayersRequest(BaseModel): + """Request schema for list_annotation_layers.""" + + filters: Annotated[ + List[AnnotationLayerFilter], + Field( + default_factory=list, + description="List of filter objects. Cannot be combined with 'search'.", + ), + ] + select_columns: Annotated[ + List[str], + Field( + default_factory=list, + description="Columns to include in the response.", + ), + ] + search: Annotated[ + str | None, + Field(default=None, description="Text search across name and description."), + ] + order_column: Annotated[ + str | None, Field(default=None, description="Column to order results by.") + ] + order_direction: Annotated[ + Literal["asc", "desc"], + Field(default="desc", description="Sort direction."), + ] + page: Annotated[ + PositiveInt, + Field(default=1, description="Page number (1-based)."), + ] + page_size: Annotated[ + int, + Field( + default=DEFAULT_PAGE_SIZE, + gt=0, + le=MAX_PAGE_SIZE, + description=f"Items per page (max {MAX_PAGE_SIZE}).", + ), + ] + + @field_validator("filters", mode="before") + @classmethod + def parse_filters(cls, v: Any) -> List[AnnotationLayerFilter]: + return parse_json_or_model_list(v, AnnotationLayerFilter, "filters") + + @field_validator("select_columns", mode="before") + @classmethod + def parse_columns(cls, v: Any) -> List[str]: + return parse_json_or_list(v, "select_columns") + + @model_validator(mode="after") + def validate_search_and_filters(self) -> "ListAnnotationLayersRequest": + if self.search and self.filters: + raise ValueError("Cannot use both 'search' and 'filters' simultaneously.") + return self + + +class GetAnnotationLayerInfoRequest(BaseModel): + """Request schema for get_annotation_layer_info.""" + + id: Annotated[int, Field(description="Annotation layer ID.")] + + +class AnnotationInfo(BaseModel): + id: int | None = Field(None, description="Annotation ID") + short_descr: str | None = Field(None, description="Short description") + long_descr: str | None = Field(None, description="Long description") + start_dttm: str | datetime | None = Field(None, description="Start datetime") + end_dttm: str | datetime | None = Field(None, description="End datetime") + json_metadata: str | None = Field(None, description="JSON metadata") + layer_id: int | None = Field(None, description="Parent annotation layer ID") + model_config = ConfigDict(from_attributes=True, ser_json_timedelta="iso8601") + + +class AnnotationList(BaseModel): + annotations: List[AnnotationInfo] + count: int + total_count: int + page: int + page_size: int + total_pages: int + has_previous: bool + has_next: bool + layer_id: int + columns_requested: List[str] = Field(default_factory=list) + columns_loaded: List[str] = Field(default_factory=list) + columns_available: List[str] = Field(default_factory=list) + sortable_columns: List[str] = Field(default_factory=list) + filters_applied: List[AnnotationFilter] = Field(default_factory=list) + pagination: PaginationInfo | None = None + timestamp: datetime | None = None + model_config = ConfigDict(ser_json_timedelta="iso8601") + + +class ListLayerAnnotationsRequest(BaseModel): + """Request schema for list_layer_annotations.""" + + layer_id: Annotated[ + int, Field(description="Annotation layer ID to list annotations for.") + ] + filters: Annotated[ + List[AnnotationFilter], + Field( + default_factory=list, + description="List of filter objects. Cannot be combined with 'search'.", + ), + ] + select_columns: Annotated[ + List[str], + Field(default_factory=list, description="Columns to include in the response."), + ] + search: Annotated[ + str | None, + Field( + default=None, description="Text search across short and long description." + ), + ] + order_column: Annotated[ + str | None, Field(default=None, description="Column to order results by.") + ] + order_direction: Annotated[ + Literal["asc", "desc"], + Field(default="desc", description="Sort direction."), + ] + page: Annotated[ + PositiveInt, + Field(default=1, description="Page number (1-based)."), + ] + page_size: Annotated[ + int, + Field( + default=DEFAULT_PAGE_SIZE, + gt=0, + le=MAX_PAGE_SIZE, + description=f"Items per page (max {MAX_PAGE_SIZE}).", + ), + ] + + @field_validator("filters", mode="before") + @classmethod + def parse_filters(cls, v: Any) -> List[AnnotationFilter]: + return parse_json_or_model_list(v, AnnotationFilter, "filters") + + @field_validator("select_columns", mode="before") + @classmethod + def parse_columns(cls, v: Any) -> List[str]: + return parse_json_or_list(v, "select_columns") + + @model_validator(mode="after") + def validate_search_and_filters(self) -> "ListLayerAnnotationsRequest": + if self.search and self.filters: + raise ValueError("Cannot use both 'search' and 'filters' simultaneously.") + return self + + +class GetLayerAnnotationInfoRequest(BaseModel): + """Request schema for get_layer_annotation_info.""" + + layer_id: Annotated[int, Field(description="Annotation layer ID.")] + annotation_id: Annotated[int, Field(description="Annotation ID.")] + + +class AnnotationLayerError(BaseModel): + error: str = Field(..., description="Error message") + error_type: str = Field(..., description="Type of error") + timestamp: str | datetime | None = Field(None, description="Error timestamp") + model_config = ConfigDict(ser_json_timedelta="iso8601") + + @classmethod + def create(cls, error: str, error_type: str) -> "AnnotationLayerError": + from datetime import timezone + + return cls( + error=error, + error_type=error_type, + timestamp=datetime.now(timezone.utc), + ) + + +def serialize_annotation_layer(obj: Any) -> AnnotationLayerInfo | None: + if not obj: + return None + return AnnotationLayerInfo( + id=getattr(obj, "id", None), + name=getattr(obj, "name", None), + descr=getattr(obj, "descr", None), + changed_on=getattr(obj, "changed_on", None), + created_on=getattr(obj, "created_on", None), + ) + + +def serialize_annotation(obj: Any) -> AnnotationInfo | None: + if not obj: + return None + return AnnotationInfo( + id=getattr(obj, "id", None), + short_descr=getattr(obj, "short_descr", None), + long_descr=getattr(obj, "long_descr", None), + start_dttm=getattr(obj, "start_dttm", None), + end_dttm=getattr(obj, "end_dttm", None), + json_metadata=getattr(obj, "json_metadata", None), + layer_id=getattr(obj, "layer_id", None), + ) diff --git a/superset/mcp_service/annotation_layer/tool/__init__.py b/superset/mcp_service/annotation_layer/tool/__init__.py new file mode 100644 index 000000000000..75bbed4e1f09 --- /dev/null +++ b/superset/mcp_service/annotation_layer/tool/__init__.py @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from .get_annotation_layer_info import get_annotation_layer_info +from .get_layer_annotation_info import get_layer_annotation_info +from .list_annotation_layers import list_annotation_layers +from .list_layer_annotations import list_layer_annotations + +__all__ = [ + "list_annotation_layers", + "get_annotation_layer_info", + "list_layer_annotations", + "get_layer_annotation_info", +] diff --git a/superset/mcp_service/annotation_layer/tool/get_annotation_layer_info.py b/superset/mcp_service/annotation_layer/tool/get_annotation_layer_info.py new file mode 100644 index 000000000000..d46c2b109271 --- /dev/null +++ b/superset/mcp_service/annotation_layer/tool/get_annotation_layer_info.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Get annotation layer info FastMCP tool.""" + +import logging +from datetime import datetime, timezone + +from fastmcp import Context +from superset_core.mcp.decorators import tool, ToolAnnotations + +from superset.extensions import event_logger +from superset.mcp_service.annotation_layer.schemas import ( + AnnotationLayerError, + AnnotationLayerInfo, + GetAnnotationLayerInfoRequest, + serialize_annotation_layer, +) + +logger = logging.getLogger(__name__) + + +@tool( + tags=["discovery"], + class_permission_name="Annotation", + annotations=ToolAnnotations( + title="Get annotation layer info", + readOnlyHint=True, + destructiveHint=False, + ), +) +async def get_annotation_layer_info( + request: GetAnnotationLayerInfoRequest, + ctx: Context, +) -> AnnotationLayerInfo | AnnotationLayerError: + """Get detailed information about an annotation layer by ID. + + Returns the layer's name, description, and timestamps. + + Example: + ```json + {"id": 1} + ``` + """ + await ctx.info("Retrieving annotation layer: id=%s" % (request.id,)) + + try: + from superset.daos.annotation_layer import AnnotationLayerDAO + + with event_logger.log_context(action="mcp.get_annotation_layer_info.lookup"): + layer = AnnotationLayerDAO.find_by_id(request.id) + + if layer is None: + await ctx.warning("Annotation layer not found: id=%s" % (request.id,)) + return AnnotationLayerError.create( + error=f"Annotation layer with id '{request.id}' not found", + error_type="not_found", + ) + + result = serialize_annotation_layer(layer) + await ctx.info( + "Annotation layer retrieved: id=%s, name=%s" + % (result.id if result else None, result.name if result else None) + ) + return result or AnnotationLayerError.create( + error="Failed to serialize annotation layer", + error_type="SerializationError", + ) + + except Exception as e: + await ctx.error( + "Annotation layer lookup failed: id=%s, error=%s, error_type=%s" + % (request.id, str(e), type(e).__name__) + ) + return AnnotationLayerError( + error=f"Failed to get annotation layer info: {str(e)}", + error_type="InternalError", + timestamp=datetime.now(timezone.utc), + ) diff --git a/superset/mcp_service/annotation_layer/tool/get_layer_annotation_info.py b/superset/mcp_service/annotation_layer/tool/get_layer_annotation_info.py new file mode 100644 index 000000000000..43ba7648e03c --- /dev/null +++ b/superset/mcp_service/annotation_layer/tool/get_layer_annotation_info.py @@ -0,0 +1,130 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Get a single annotation within a layer FastMCP tool.""" + +import logging +from datetime import datetime, timezone + +from fastmcp import Context +from superset_core.mcp.decorators import tool, ToolAnnotations + +from superset.extensions import event_logger +from superset.mcp_service.annotation_layer.schemas import ( + AnnotationInfo, + AnnotationLayerError, + GetLayerAnnotationInfoRequest, + serialize_annotation, +) + +logger = logging.getLogger(__name__) + + +@tool( + tags=["discovery"], + class_permission_name="Annotation", + annotations=ToolAnnotations( + title="Get annotation info", + readOnlyHint=True, + destructiveHint=False, + ), +) +async def get_layer_annotation_info( + request: GetLayerAnnotationInfoRequest, + ctx: Context, +) -> AnnotationInfo | AnnotationLayerError: + """Get detailed information about a specific annotation within a layer. + + Both layer_id and annotation_id are required. Returns an error if the + annotation does not belong to the specified layer. + + Example: + ```json + {"layer_id": 1, "annotation_id": 42} + ``` + """ + await ctx.info( + "Retrieving annotation: layer_id=%s, annotation_id=%s" + % (request.layer_id, request.annotation_id) + ) + + try: + from superset.daos.annotation_layer import AnnotationDAO, AnnotationLayerDAO + + # Verify the layer exists + with event_logger.log_context( + action="mcp.get_layer_annotation_info.layer_lookup" + ): + layer = AnnotationLayerDAO.find_by_id(request.layer_id) + + if layer is None: + await ctx.warning("Annotation layer not found: id=%s" % (request.layer_id,)) + return AnnotationLayerError.create( + error=f"Annotation layer with id '{request.layer_id}' not found", + error_type="not_found", + ) + + # Fetch the annotation + with event_logger.log_context( + action="mcp.get_layer_annotation_info.annotation_lookup" + ): + annotation = AnnotationDAO.find_by_id(request.annotation_id) + + if annotation is None: + await ctx.warning( + "Annotation not found: annotation_id=%s" % (request.annotation_id,) + ) + return AnnotationLayerError.create( + error=f"Annotation with id '{request.annotation_id}' not found", + error_type="not_found", + ) + + # Verify the annotation belongs to the requested layer + if getattr(annotation, "layer_id", None) != request.layer_id: + await ctx.warning( + "Annotation %s does not belong to layer %s" + % (request.annotation_id, request.layer_id) + ) + return AnnotationLayerError.create( + error=( + f"Annotation '{request.annotation_id}' does not belong to " + f"layer '{request.layer_id}'" + ), + error_type="not_found", + ) + + result = serialize_annotation(annotation) + await ctx.info( + "Annotation retrieved: id=%s, short_descr=%s" + % (result.id if result else None, result.short_descr if result else None) + ) + return result or AnnotationLayerError.create( + error="Failed to serialize annotation", + error_type="SerializationError", + ) + + except Exception as e: + await ctx.error( + "Annotation lookup failed: layer_id=%s, annotation_id=%s, " + "error=%s, error_type=%s" + % (request.layer_id, request.annotation_id, str(e), type(e).__name__) + ) + return AnnotationLayerError( + error=f"Failed to get annotation info: {str(e)}", + error_type="InternalError", + timestamp=datetime.now(timezone.utc), + ) diff --git a/superset/mcp_service/annotation_layer/tool/list_annotation_layers.py b/superset/mcp_service/annotation_layer/tool/list_annotation_layers.py new file mode 100644 index 000000000000..fc924e428f2d --- /dev/null +++ b/superset/mcp_service/annotation_layer/tool/list_annotation_layers.py @@ -0,0 +1,123 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""List annotation layers FastMCP tool.""" + +import logging + +from fastmcp import Context +from superset_core.mcp.decorators import tool, ToolAnnotations + +from superset.extensions import event_logger +from superset.mcp_service.annotation_layer.schemas import ( + AnnotationLayerError, + AnnotationLayerFilter, + AnnotationLayerInfo, + AnnotationLayerList, + DEFAULT_LAYER_COLUMNS, + ListAnnotationLayersRequest, + serialize_annotation_layer, +) +from superset.mcp_service.mcp_core import ModelListCore + +logger = logging.getLogger(__name__) + +_DEFAULT_REQUEST = ListAnnotationLayersRequest() + +_ALL_LAYER_COLUMNS = ["id", "name", "descr", "changed_on", "created_on"] +_SORTABLE_LAYER_COLUMNS = ["id", "name", "changed_on", "created_on"] + + +@tool( + tags=["core"], + class_permission_name="Annotation", + annotations=ToolAnnotations( + title="List annotation layers", + readOnlyHint=True, + destructiveHint=False, + ), +) +async def list_annotation_layers( + request: ListAnnotationLayersRequest | None = None, + ctx: Context | None = None, +) -> AnnotationLayerList | AnnotationLayerError: + """List annotation layers with filtering, search, and pagination. + + Returns annotation layer metadata including name and description. + + Sortable columns for order_column: id, name, changed_on, created_on + """ + if ctx is None: + raise RuntimeError("FastMCP context is required for list_annotation_layers") + + request = request or _DEFAULT_REQUEST.model_copy(deep=True) + + await ctx.info( + "Listing annotation layers: page=%s, page_size=%s, search=%s" + % (request.page, request.page_size, request.search) + ) + + try: + from superset.daos.annotation_layer import AnnotationLayerDAO + + def _serialize( + obj: object, cols: list[str] | None + ) -> AnnotationLayerInfo | None: + return serialize_annotation_layer(obj) + + list_tool = ModelListCore( + dao_class=AnnotationLayerDAO, + output_schema=AnnotationLayerInfo, + item_serializer=_serialize, + filter_type=AnnotationLayerFilter, + default_columns=DEFAULT_LAYER_COLUMNS, + search_columns=["name"], + list_field_name="annotation_layers", + output_list_schema=AnnotationLayerList, + all_columns=_ALL_LAYER_COLUMNS, + sortable_columns=_SORTABLE_LAYER_COLUMNS, + logger=logger, + ) + + with event_logger.log_context(action="mcp.list_annotation_layers.query"): + result = list_tool.run_tool( + filters=request.filters, + search=request.search, + select_columns=request.select_columns, + order_column=request.order_column, + order_direction=request.order_direction, + page=max(request.page - 1, 0), + page_size=request.page_size, + ) + + await ctx.info( + "Annotation layers listed: count=%s, total_count=%s" + % ( + len(result.annotation_layers) + if hasattr(result, "annotation_layers") + else 0, + getattr(result, "total_count", None), + ) + ) + return result + + except Exception as e: + await ctx.error( + "Annotation layer listing failed: error=%s, error_type=%s" + % (str(e), type(e).__name__) + ) + raise diff --git a/superset/mcp_service/annotation_layer/tool/list_layer_annotations.py b/superset/mcp_service/annotation_layer/tool/list_layer_annotations.py new file mode 100644 index 000000000000..279beae094ac --- /dev/null +++ b/superset/mcp_service/annotation_layer/tool/list_layer_annotations.py @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""List annotations within a layer FastMCP tool.""" + +import logging +from datetime import datetime, timezone + +from fastmcp import Context +from superset_core.mcp.decorators import tool, ToolAnnotations + +from superset.daos.base import ColumnOperator, ColumnOperatorEnum +from superset.extensions import event_logger +from superset.mcp_service.annotation_layer.schemas import ( + AnnotationFilter, + AnnotationInfo, + AnnotationLayerError, + AnnotationList, + DEFAULT_ANNOTATION_COLUMNS, + ListLayerAnnotationsRequest, + serialize_annotation, +) +from superset.mcp_service.mcp_core import ModelListCore + +logger = logging.getLogger(__name__) + +_ALL_ANNOTATION_COLUMNS = [ + "id", + "short_descr", + "long_descr", + "start_dttm", + "end_dttm", + "json_metadata", + "layer_id", +] +_SORTABLE_ANNOTATION_COLUMNS = ["id", "short_descr", "start_dttm", "end_dttm"] + + +@tool( + tags=["core"], + class_permission_name="Annotation", + annotations=ToolAnnotations( + title="List annotations in a layer", + readOnlyHint=True, + destructiveHint=False, + ), +) +async def list_layer_annotations( + request: ListLayerAnnotationsRequest, + ctx: Context, +) -> AnnotationList | AnnotationLayerError: + """List annotations within a specific annotation layer. + + The layer_id parameter is required and scopes all results to that layer. + + Sortable columns for order_column: id, short_descr, start_dttm, end_dttm + + Example: + ```json + {"layer_id": 1, "page": 1, "page_size": 25} + ``` + """ + await ctx.info( + "Listing annotations: layer_id=%s, page=%s, page_size=%s, search=%s" + % (request.layer_id, request.page, request.page_size, request.search) + ) + + try: + from superset.daos.annotation_layer import AnnotationDAO, AnnotationLayerDAO + + # Verify the layer exists before listing + layer = AnnotationLayerDAO.find_by_id(request.layer_id) + if layer is None: + await ctx.warning("Annotation layer not found: id=%s" % (request.layer_id,)) + return AnnotationLayerError.create( + error=f"Annotation layer with id '{request.layer_id}' not found", + error_type="not_found", + ) + + # Prepend the layer_id filter so results are scoped to this layer + layer_filter = ColumnOperator( + col="layer_id", opr=ColumnOperatorEnum.eq, value=request.layer_id + ) + combined_filters: list[ColumnOperator] = [layer_filter] + list(request.filters) + + def _serialize(obj: object, cols: list[str] | None) -> AnnotationInfo | None: + return serialize_annotation(obj) + + list_tool = ModelListCore( + dao_class=AnnotationDAO, + output_schema=AnnotationInfo, + item_serializer=_serialize, + filter_type=AnnotationFilter, + default_columns=DEFAULT_ANNOTATION_COLUMNS, + search_columns=["short_descr"], + list_field_name="annotations", + output_list_schema=AnnotationList, + all_columns=_ALL_ANNOTATION_COLUMNS, + sortable_columns=_SORTABLE_ANNOTATION_COLUMNS, + logger=logger, + ) + + with event_logger.log_context(action="mcp.list_layer_annotations.query"): + result = list_tool.run_tool( + filters=combined_filters, + search=request.search, + select_columns=request.select_columns, + order_column=request.order_column, + order_direction=request.order_direction, + page=max(request.page - 1, 0), + page_size=request.page_size, + ) + + # Attach the layer_id to the result for caller context + result_dict = result.model_dump() + result_dict["layer_id"] = request.layer_id + # Rebuild with layer_id set + final = AnnotationList(**result_dict) + + await ctx.info( + "Annotations listed: layer_id=%s, count=%s, total_count=%s" + % ( + request.layer_id, + len(final.annotations) if hasattr(final, "annotations") else 0, + getattr(final, "total_count", None), + ) + ) + return final + + except Exception as e: + await ctx.error( + "Annotation listing failed: layer_id=%s, error=%s, error_type=%s" + % (request.layer_id, str(e), type(e).__name__) + ) + return AnnotationLayerError( + error=f"Failed to list annotations: {str(e)}", + error_type="InternalError", + timestamp=datetime.now(timezone.utc), + ) diff --git a/superset/mcp_service/app.py b/superset/mcp_service/app.py index 81c6bd1f0886..5f0711c427ea 100644 --- a/superset/mcp_service/app.py +++ b/superset/mcp_service/app.py @@ -119,6 +119,12 @@ def get_default_instructions( - generate_dashboard: Create a dashboard from chart IDs - add_chart_to_existing_dashboard: Add a chart to an existing dashboard +Annotation Layers: +- list_annotation_layers: List annotation layers with advanced filters (1-based pagination) +- get_annotation_layer_info: Get annotation layer details by ID +- list_layer_annotations: List annotations within a layer (requires layer_id, 1-based pagination) +- get_layer_annotation_info: Get annotation details by layer_id and annotation_id + Database Connections: - list_databases: List database connections with advanced filters (1-based pagination) - get_database_info: Get detailed database connection info by ID (backend, capabilities) @@ -602,6 +608,12 @@ def create_mcp_app( # NOTE: Always add new prompt/resource imports here when creating new prompts/resources. # Prompts use @mcp.prompt decorators and resources use @mcp.resource decorators. # They register automatically on import, similar to tools. +from superset.mcp_service.annotation_layer.tool import ( # noqa: F401, E402 + get_annotation_layer_info, + get_layer_annotation_info, + list_annotation_layers, + list_layer_annotations, +) from superset.mcp_service.chart import ( # noqa: F401, E402 prompts as chart_prompts, resources as chart_resources, diff --git a/tests/unit_tests/mcp_service/annotation_layer/__init__.py b/tests/unit_tests/mcp_service/annotation_layer/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/unit_tests/mcp_service/annotation_layer/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/mcp_service/annotation_layer/tool/__init__.py b/tests/unit_tests/mcp_service/annotation_layer/tool/__init__.py new file mode 100644 index 000000000000..13a83393a912 --- /dev/null +++ b/tests/unit_tests/mcp_service/annotation_layer/tool/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/mcp_service/annotation_layer/tool/test_annotation_layer_tools.py b/tests/unit_tests/mcp_service/annotation_layer/tool/test_annotation_layer_tools.py new file mode 100644 index 000000000000..2963accd349c --- /dev/null +++ b/tests/unit_tests/mcp_service/annotation_layer/tool/test_annotation_layer_tools.py @@ -0,0 +1,434 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import logging +from unittest.mock import MagicMock, patch + +import pytest +from fastmcp import Client +from pydantic import ValidationError + +from superset.mcp_service.annotation_layer.schemas import ( + AnnotationFilter, + AnnotationLayerFilter, + ListAnnotationLayersRequest, + ListLayerAnnotationsRequest, +) +from superset.mcp_service.app import mcp +from superset.utils import json + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_layer( + layer_id: int = 1, name: str = "My Layer", descr: str = "desc" +) -> MagicMock: + obj = MagicMock() + obj.id = layer_id + obj.name = name + obj.descr = descr + obj.changed_on = None + obj.created_on = None + return obj + + +def make_annotation( + annotation_id: int = 10, + layer_id: int = 1, + short_descr: str = "Deploy", + long_descr: str = "Deployment annotation", +) -> MagicMock: + obj = MagicMock() + obj.id = annotation_id + obj.layer_id = layer_id + obj.short_descr = short_descr + obj.long_descr = long_descr + obj.start_dttm = None + obj.end_dttm = None + obj.json_metadata = None + return obj + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def mcp_server(): + return mcp + + +@pytest.fixture(autouse=True) +def mock_auth(): + from unittest.mock import Mock + + with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user: + mock_user = Mock() + mock_user.id = 1 + mock_user.username = "admin" + mock_get_user.return_value = mock_user + yield mock_get_user + + +# --------------------------------------------------------------------------- +# Schema validation tests +# --------------------------------------------------------------------------- + + +class TestAnnotationLayerFilterSchema: + def test_valid_name_filter(self): + f = AnnotationLayerFilter(col="name", opr="eq", value="My Layer") + assert f.col == "name" + + def test_invalid_column_rejected(self): + with pytest.raises(ValidationError): + AnnotationLayerFilter(col="descr", opr="eq", value="x") + + def test_search_and_filters_mutual_exclusion(self): + with pytest.raises(ValidationError): + ListAnnotationLayersRequest( + search="foo", + filters=[{"col": "name", "opr": "eq", "value": "bar"}], + ) + + +class TestAnnotationFilterSchema: + def test_valid_short_descr_filter(self): + f = AnnotationFilter(col="short_descr", opr="eq", value="Deploy") + assert f.col == "short_descr" + + def test_invalid_column_rejected(self): + with pytest.raises(ValidationError): + AnnotationFilter(col="layer_id", opr="eq", value=1) + + def test_search_and_filters_mutual_exclusion(self): + with pytest.raises(ValidationError): + ListLayerAnnotationsRequest( + layer_id=1, + search="foo", + filters=[{"col": "short_descr", "opr": "eq", "value": "bar"}], + ) + + +# --------------------------------------------------------------------------- +# list_annotation_layers tests +# --------------------------------------------------------------------------- + + +@patch("superset.daos.annotation_layer.AnnotationLayerDAO.list") +@pytest.mark.asyncio() +async def test_list_annotation_layers_basic(mock_list, mcp_server): + """Basic listing returns structured response with annotation layers.""" + layer = make_layer() + mock_list.return_value = ([layer], 1) + + async with Client(mcp_server) as client: + result = await client.call_tool( + "list_annotation_layers", + {"request": {"page": 1, "page_size": 10}}, + ) + + data = json.loads(result.content[0].text) + assert data["annotation_layers"] is not None + assert len(data["annotation_layers"]) == 1 + assert data["annotation_layers"][0]["id"] == 1 + assert data["annotation_layers"][0]["name"] == "My Layer" + + +@patch("superset.daos.annotation_layer.AnnotationLayerDAO.list") +@pytest.mark.asyncio() +async def test_list_annotation_layers_empty(mock_list, mcp_server): + """Empty result set returns zero count.""" + mock_list.return_value = ([], 0) + + async with Client(mcp_server) as client: + result = await client.call_tool("list_annotation_layers", {}) + + data = json.loads(result.content[0].text) + assert data["annotation_layers"] == [] + assert data["total_count"] == 0 + + +@patch("superset.daos.annotation_layer.AnnotationLayerDAO.list") +@pytest.mark.asyncio() +async def test_list_annotation_layers_search(mock_list, mcp_server): + """Search parameter is passed through to DAO.""" + layer = make_layer(name="Release Events") + mock_list.return_value = ([layer], 1) + + async with Client(mcp_server) as client: + result = await client.call_tool( + "list_annotation_layers", + {"request": {"search": "release"}}, + ) + + data = json.loads(result.content[0].text) + assert data["annotation_layers"][0]["name"] == "Release Events" + call_kwargs = mock_list.call_args.kwargs + assert call_kwargs["search"] == "release" + + +@patch("superset.daos.annotation_layer.AnnotationLayerDAO.list") +@pytest.mark.asyncio() +async def test_list_annotation_layers_pagination(mock_list, mcp_server): + """Pagination metadata is correctly computed.""" + mock_list.return_value = ([], 50) + + async with Client(mcp_server) as client: + result = await client.call_tool( + "list_annotation_layers", + {"request": {"page": 2, "page_size": 25}}, + ) + + data = json.loads(result.content[0].text) + assert data["page"] == 2 + assert data["page_size"] == 25 + assert data["total_count"] == 50 + assert data["total_pages"] == 2 + # Page 2 of 2, so no next page + assert data["has_next"] is False + assert data["has_previous"] is True + + +# --------------------------------------------------------------------------- +# get_annotation_layer_info tests +# --------------------------------------------------------------------------- + + +@patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") +@pytest.mark.asyncio() +async def test_get_annotation_layer_info_found(mock_find, mcp_server): + """Returns annotation layer data when found.""" + mock_find.return_value = make_layer(layer_id=5, name="Prod Events") + + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_annotation_layer_info", + {"request": {"id": 5}}, + ) + + data = json.loads(result.content[0].text) + assert data["id"] == 5 + assert data["name"] == "Prod Events" + mock_find.assert_called_once_with(5) + + +@patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") +@pytest.mark.asyncio() +async def test_get_annotation_layer_info_not_found(mock_find, mcp_server): + """Returns error response when layer is not found.""" + mock_find.return_value = None + + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_annotation_layer_info", + {"request": {"id": 999}}, + ) + + data = json.loads(result.content[0].text) + assert data["error_type"] == "not_found" + assert "999" in data["error"] + + +# --------------------------------------------------------------------------- +# list_layer_annotations tests +# --------------------------------------------------------------------------- + + +@patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") +@patch("superset.daos.annotation_layer.AnnotationDAO.list") +@pytest.mark.asyncio() +async def test_list_layer_annotations_basic(mock_list, mock_layer_find, mcp_server): + """Annotations are listed and scoped to the specified layer.""" + mock_layer_find.return_value = make_layer(layer_id=1) + ann = make_annotation(annotation_id=10, layer_id=1) + mock_list.return_value = ([ann], 1) + + async with Client(mcp_server) as client: + result = await client.call_tool( + "list_layer_annotations", + {"request": {"layer_id": 1, "page": 1, "page_size": 10}}, + ) + + data = json.loads(result.content[0].text) + assert data["layer_id"] == 1 + assert len(data["annotations"]) == 1 + assert data["annotations"][0]["id"] == 10 + assert data["annotations"][0]["layer_id"] == 1 + + +@patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") +@patch("superset.daos.annotation_layer.AnnotationDAO.list") +@pytest.mark.asyncio() +async def test_list_layer_annotations_layer_id_filter_prepended( + mock_list, mock_layer_find, mcp_server +): + """The layer_id filter is always prepended to DAO column_operators.""" + mock_layer_find.return_value = make_layer(layer_id=3) + mock_list.return_value = ([], 0) + + async with Client(mcp_server) as client: + await client.call_tool( + "list_layer_annotations", + {"request": {"layer_id": 3}}, + ) + + call_kwargs = mock_list.call_args.kwargs + filters = call_kwargs.get("column_operators", []) + # First filter must be the layer_id eq filter + assert filters, "Expected at least one filter (layer_id)" + first = filters[0] + col = first.get("col") if isinstance(first, dict) else getattr(first, "col", None) + val = ( + first.get("value") if isinstance(first, dict) else getattr(first, "value", None) + ) + assert col == "layer_id" + assert val == 3 + + +@patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") +@pytest.mark.asyncio() +async def test_list_layer_annotations_layer_not_found(mock_layer_find, mcp_server): + """Returns error when the layer does not exist.""" + mock_layer_find.return_value = None + + async with Client(mcp_server) as client: + result = await client.call_tool( + "list_layer_annotations", + {"request": {"layer_id": 42}}, + ) + + data = json.loads(result.content[0].text) + assert data["error_type"] == "not_found" + assert "42" in data["error"] + + +@patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") +@patch("superset.daos.annotation_layer.AnnotationDAO.list") +@pytest.mark.asyncio() +async def test_list_layer_annotations_only_returns_own_layer( + mock_list, mock_layer_find, mcp_server +): + """Results are filtered to the requested layer only — wrong layer_id is rejected.""" + mock_layer_find.return_value = make_layer(layer_id=1) + # Simulate DAO returning annotations — the layer_id filter is applied at DB level + ann_wrong = make_annotation(annotation_id=99, layer_id=2) + mock_list.return_value = ([ann_wrong], 1) + + async with Client(mcp_server) as client: + result = await client.call_tool( + "list_layer_annotations", + {"request": {"layer_id": 1}}, + ) + + data = json.loads(result.content[0].text) + # layer_id in response header must still be 1 (the requested layer) + assert data["layer_id"] == 1 + + +# --------------------------------------------------------------------------- +# get_layer_annotation_info tests +# --------------------------------------------------------------------------- + + +@patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") +@patch("superset.daos.annotation_layer.AnnotationDAO.find_by_id") +@pytest.mark.asyncio() +async def test_get_layer_annotation_info_found( + mock_ann_find, mock_layer_find, mcp_server +): + """Returns annotation data when both layer and annotation are found.""" + mock_layer_find.return_value = make_layer(layer_id=1) + mock_ann_find.return_value = make_annotation(annotation_id=10, layer_id=1) + + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_layer_annotation_info", + {"request": {"layer_id": 1, "annotation_id": 10}}, + ) + + data = json.loads(result.content[0].text) + assert data["id"] == 10 + assert data["layer_id"] == 1 + assert data["short_descr"] == "Deploy" + + +@patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") +@pytest.mark.asyncio() +async def test_get_layer_annotation_info_layer_not_found(mock_layer_find, mcp_server): + """Returns error when the layer does not exist.""" + mock_layer_find.return_value = None + + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_layer_annotation_info", + {"request": {"layer_id": 99, "annotation_id": 10}}, + ) + + data = json.loads(result.content[0].text) + assert data["error_type"] == "not_found" + assert "99" in data["error"] + + +@patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") +@patch("superset.daos.annotation_layer.AnnotationDAO.find_by_id") +@pytest.mark.asyncio() +async def test_get_layer_annotation_info_annotation_not_found( + mock_ann_find, mock_layer_find, mcp_server +): + """Returns error when the annotation does not exist.""" + mock_layer_find.return_value = make_layer(layer_id=1) + mock_ann_find.return_value = None + + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_layer_annotation_info", + {"request": {"layer_id": 1, "annotation_id": 999}}, + ) + + data = json.loads(result.content[0].text) + assert data["error_type"] == "not_found" + assert "999" in data["error"] + + +@patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") +@patch("superset.daos.annotation_layer.AnnotationDAO.find_by_id") +@pytest.mark.asyncio() +async def test_get_layer_annotation_info_wrong_layer( + mock_ann_find, mock_layer_find, mcp_server +): + """Returns error when annotation exists but belongs to a different layer.""" + mock_layer_find.return_value = make_layer(layer_id=1) + # Annotation belongs to layer 2, not layer 1 + mock_ann_find.return_value = make_annotation(annotation_id=10, layer_id=2) + + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_layer_annotation_info", + {"request": {"layer_id": 1, "annotation_id": 10}}, + ) + + data = json.loads(result.content[0].text) + assert data["error_type"] == "not_found" + assert "does not belong" in data["error"] From dba55e0f1f4a3257abd9c156636341a54201e63e Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Wed, 20 May 2026 23:02:39 +0000 Subject: [PATCH 2/6] refactor(mcp): address code review feedback on annotation layer tools - Use ModelGetInfoCore in get_annotation_layer_info (DRY, matches other get_info tools) - Expand search_columns to include descr and long_descr (align with docstrings) - Simplify layer_id assignment in list_layer_annotations (direct attribute set vs rebuild) - Fix AnnotationList.layer_id default to 0 (ModelListCore cannot inject domain fields) Co-Authored-By: Claude Sonnet 4.6 --- .../mcp_service/annotation_layer/schemas.py | 58 ++++++++++--------- .../tool/get_annotation_layer_info.py | 34 ++++++----- .../tool/list_annotation_layers.py | 2 +- .../tool/list_layer_annotations.py | 14 ++--- 4 files changed, 57 insertions(+), 51 deletions(-) diff --git a/superset/mcp_service/annotation_layer/schemas.py b/superset/mcp_service/annotation_layer/schemas.py index 11059ab75b43..2e44949daf54 100644 --- a/superset/mcp_service/annotation_layer/schemas.py +++ b/superset/mcp_service/annotation_layer/schemas.py @@ -20,7 +20,7 @@ from __future__ import annotations from datetime import datetime -from typing import Annotated, Any, List, Literal +from typing import Annotated, Any, Literal from pydantic import ( BaseModel, @@ -51,7 +51,7 @@ class AnnotationLayerFilter(ColumnOperator): description="Column to filter on. Supported: 'name'.", ) opr: ColumnOperatorEnum = Field(..., description="Filter operator.") - value: str | int | float | bool | List[str | int | float | bool] = Field( + value: str | int | float | bool | list[str | int | float | bool] = Field( ..., description="Value to filter by." ) @@ -64,7 +64,7 @@ class AnnotationFilter(ColumnOperator): description="Column to filter on. Supported: 'short_descr'.", ) opr: ColumnOperatorEnum = Field(..., description="Filter operator.") - value: str | int | float | bool | List[str | int | float | bool] = Field( + value: str | int | float | bool | list[str | int | float | bool] = Field( ..., description="Value to filter by." ) @@ -81,7 +81,7 @@ class AnnotationLayerInfo(BaseModel): class AnnotationLayerList(BaseModel): - annotation_layers: List[AnnotationLayerInfo] + annotation_layers: list[AnnotationLayerInfo] count: int total_count: int page: int @@ -89,11 +89,11 @@ class AnnotationLayerList(BaseModel): total_pages: int has_previous: bool has_next: bool - columns_requested: List[str] = Field(default_factory=list) - columns_loaded: List[str] = Field(default_factory=list) - columns_available: List[str] = Field(default_factory=list) - sortable_columns: List[str] = Field(default_factory=list) - filters_applied: List[AnnotationLayerFilter] = Field(default_factory=list) + columns_requested: list[str] = Field(default_factory=list) + columns_loaded: list[str] = Field(default_factory=list) + columns_available: list[str] = Field(default_factory=list) + sortable_columns: list[str] = Field(default_factory=list) + filters_applied: list[AnnotationLayerFilter] = Field(default_factory=list) pagination: PaginationInfo | None = None timestamp: datetime | None = None model_config = ConfigDict(ser_json_timedelta="iso8601") @@ -103,14 +103,14 @@ class ListAnnotationLayersRequest(BaseModel): """Request schema for list_annotation_layers.""" filters: Annotated[ - List[AnnotationLayerFilter], + list[AnnotationLayerFilter], Field( default_factory=list, description="List of filter objects. Cannot be combined with 'search'.", ), ] select_columns: Annotated[ - List[str], + list[str], Field( default_factory=list, description="Columns to include in the response.", @@ -118,7 +118,10 @@ class ListAnnotationLayersRequest(BaseModel): ] search: Annotated[ str | None, - Field(default=None, description="Text search across name and description."), + Field( + default=None, + description="Text search across annotation layer name and description.", + ), ] order_column: Annotated[ str | None, Field(default=None, description="Column to order results by.") @@ -143,12 +146,12 @@ class ListAnnotationLayersRequest(BaseModel): @field_validator("filters", mode="before") @classmethod - def parse_filters(cls, v: Any) -> List[AnnotationLayerFilter]: + def parse_filters(cls, v: Any) -> list[AnnotationLayerFilter]: return parse_json_or_model_list(v, AnnotationLayerFilter, "filters") @field_validator("select_columns", mode="before") @classmethod - def parse_columns(cls, v: Any) -> List[str]: + def parse_columns(cls, v: Any) -> list[str]: return parse_json_or_list(v, "select_columns") @model_validator(mode="after") @@ -176,7 +179,7 @@ class AnnotationInfo(BaseModel): class AnnotationList(BaseModel): - annotations: List[AnnotationInfo] + annotations: list[AnnotationInfo] count: int total_count: int page: int @@ -184,12 +187,14 @@ class AnnotationList(BaseModel): total_pages: int has_previous: bool has_next: bool - layer_id: int - columns_requested: List[str] = Field(default_factory=list) - columns_loaded: List[str] = Field(default_factory=list) - columns_available: List[str] = Field(default_factory=list) - sortable_columns: List[str] = Field(default_factory=list) - filters_applied: List[AnnotationFilter] = Field(default_factory=list) + # layer_id defaults to 0; the tool sets it after ModelListCore constructs this + # object. ModelListCore does not know about this domain-specific field. + layer_id: int = 0 + columns_requested: list[str] = Field(default_factory=list) + columns_loaded: list[str] = Field(default_factory=list) + columns_available: list[str] = Field(default_factory=list) + sortable_columns: list[str] = Field(default_factory=list) + filters_applied: list[AnnotationFilter] = Field(default_factory=list) pagination: PaginationInfo | None = None timestamp: datetime | None = None model_config = ConfigDict(ser_json_timedelta="iso8601") @@ -202,20 +207,21 @@ class ListLayerAnnotationsRequest(BaseModel): int, Field(description="Annotation layer ID to list annotations for.") ] filters: Annotated[ - List[AnnotationFilter], + list[AnnotationFilter], Field( default_factory=list, description="List of filter objects. Cannot be combined with 'search'.", ), ] select_columns: Annotated[ - List[str], + list[str], Field(default_factory=list, description="Columns to include in the response."), ] search: Annotated[ str | None, Field( - default=None, description="Text search across short and long description." + default=None, + description="Text search across annotation short and long description.", ), ] order_column: Annotated[ @@ -241,12 +247,12 @@ class ListLayerAnnotationsRequest(BaseModel): @field_validator("filters", mode="before") @classmethod - def parse_filters(cls, v: Any) -> List[AnnotationFilter]: + def parse_filters(cls, v: Any) -> list[AnnotationFilter]: return parse_json_or_model_list(v, AnnotationFilter, "filters") @field_validator("select_columns", mode="before") @classmethod - def parse_columns(cls, v: Any) -> List[str]: + def parse_columns(cls, v: Any) -> list[str]: return parse_json_or_list(v, "select_columns") @model_validator(mode="after") diff --git a/superset/mcp_service/annotation_layer/tool/get_annotation_layer_info.py b/superset/mcp_service/annotation_layer/tool/get_annotation_layer_info.py index d46c2b109271..ad19260a89ed 100644 --- a/superset/mcp_service/annotation_layer/tool/get_annotation_layer_info.py +++ b/superset/mcp_service/annotation_layer/tool/get_annotation_layer_info.py @@ -30,6 +30,7 @@ GetAnnotationLayerInfoRequest, serialize_annotation_layer, ) +from superset.mcp_service.mcp_core import ModelGetInfoCore logger = logging.getLogger(__name__) @@ -62,24 +63,27 @@ async def get_annotation_layer_info( from superset.daos.annotation_layer import AnnotationLayerDAO with event_logger.log_context(action="mcp.get_annotation_layer_info.lookup"): - layer = AnnotationLayerDAO.find_by_id(request.id) + get_tool = ModelGetInfoCore( + dao_class=AnnotationLayerDAO, + output_schema=AnnotationLayerInfo, + error_schema=AnnotationLayerError, + serializer=serialize_annotation_layer, + supports_slug=False, + logger=logger, + ) + result = get_tool.run_tool(request.id) - if layer is None: - await ctx.warning("Annotation layer not found: id=%s" % (request.id,)) - return AnnotationLayerError.create( - error=f"Annotation layer with id '{request.id}' not found", - error_type="not_found", + if isinstance(result, AnnotationLayerInfo): + await ctx.info( + "Annotation layer retrieved: id=%s, name=%s" % (result.id, result.name) + ) + else: + await ctx.warning( + "Annotation layer not found: id=%s, error_type=%s" + % (request.id, result.error_type) ) - result = serialize_annotation_layer(layer) - await ctx.info( - "Annotation layer retrieved: id=%s, name=%s" - % (result.id if result else None, result.name if result else None) - ) - return result or AnnotationLayerError.create( - error="Failed to serialize annotation layer", - error_type="SerializationError", - ) + return result except Exception as e: await ctx.error( diff --git a/superset/mcp_service/annotation_layer/tool/list_annotation_layers.py b/superset/mcp_service/annotation_layer/tool/list_annotation_layers.py index fc924e428f2d..1bd3050fb3d1 100644 --- a/superset/mcp_service/annotation_layer/tool/list_annotation_layers.py +++ b/superset/mcp_service/annotation_layer/tool/list_annotation_layers.py @@ -85,7 +85,7 @@ def _serialize( item_serializer=_serialize, filter_type=AnnotationLayerFilter, default_columns=DEFAULT_LAYER_COLUMNS, - search_columns=["name"], + search_columns=["name", "descr"], list_field_name="annotation_layers", output_list_schema=AnnotationLayerList, all_columns=_ALL_LAYER_COLUMNS, diff --git a/superset/mcp_service/annotation_layer/tool/list_layer_annotations.py b/superset/mcp_service/annotation_layer/tool/list_layer_annotations.py index 279beae094ac..658bab6684f9 100644 --- a/superset/mcp_service/annotation_layer/tool/list_layer_annotations.py +++ b/superset/mcp_service/annotation_layer/tool/list_layer_annotations.py @@ -106,7 +106,7 @@ def _serialize(obj: object, cols: list[str] | None) -> AnnotationInfo | None: item_serializer=_serialize, filter_type=AnnotationFilter, default_columns=DEFAULT_ANNOTATION_COLUMNS, - search_columns=["short_descr"], + search_columns=["short_descr", "long_descr"], list_field_name="annotations", output_list_schema=AnnotationList, all_columns=_ALL_ANNOTATION_COLUMNS, @@ -125,21 +125,17 @@ def _serialize(obj: object, cols: list[str] | None) -> AnnotationInfo | None: page_size=request.page_size, ) - # Attach the layer_id to the result for caller context - result_dict = result.model_dump() - result_dict["layer_id"] = request.layer_id - # Rebuild with layer_id set - final = AnnotationList(**result_dict) + result.layer_id = request.layer_id await ctx.info( "Annotations listed: layer_id=%s, count=%s, total_count=%s" % ( request.layer_id, - len(final.annotations) if hasattr(final, "annotations") else 0, - getattr(final, "total_count", None), + len(result.annotations) if hasattr(result, "annotations") else 0, + getattr(result, "total_count", None), ) ) - return final + return result except Exception as e: await ctx.error( From 6e71872e971603bd767fc0486c4b670152cc1d80 Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Thu, 21 May 2026 01:28:45 +0000 Subject: [PATCH 3/6] fix(mcp): fix three failing unit tests in annotation layer tools - AnnotationList.filters_applied used list[AnnotationFilter] which only allows col="short_descr", but ModelListCore injects a col="layer_id" ColumnOperator that fails Pydantic validation, causing KeyError on layer_id in the JSON response. Changed to list[ColumnOperator]. - test_get_annotation_layer_info_found: ModelGetInfoCore._find_object calls find_by_id(id, query_options=None), not find_by_id(id); updated assertion to assert_called_once_with(5, query_options=None). --- .../mcp_service/annotation_layer/schemas.py | 2 +- .../tool/test_annotation_layer_tools.py | 32 +++++++++---------- 2 files changed, 17 insertions(+), 17 deletions(-) diff --git a/superset/mcp_service/annotation_layer/schemas.py b/superset/mcp_service/annotation_layer/schemas.py index 2e44949daf54..175ee25aeec6 100644 --- a/superset/mcp_service/annotation_layer/schemas.py +++ b/superset/mcp_service/annotation_layer/schemas.py @@ -194,7 +194,7 @@ class AnnotationList(BaseModel): columns_loaded: list[str] = Field(default_factory=list) columns_available: list[str] = Field(default_factory=list) sortable_columns: list[str] = Field(default_factory=list) - filters_applied: list[AnnotationFilter] = Field(default_factory=list) + filters_applied: list[ColumnOperator] = Field(default_factory=list) pagination: PaginationInfo | None = None timestamp: datetime | None = None model_config = ConfigDict(ser_json_timedelta="iso8601") diff --git a/tests/unit_tests/mcp_service/annotation_layer/tool/test_annotation_layer_tools.py b/tests/unit_tests/mcp_service/annotation_layer/tool/test_annotation_layer_tools.py index 2963accd349c..7bef662bfb5b 100644 --- a/tests/unit_tests/mcp_service/annotation_layer/tool/test_annotation_layer_tools.py +++ b/tests/unit_tests/mcp_service/annotation_layer/tool/test_annotation_layer_tools.py @@ -74,7 +74,7 @@ def make_annotation( # --------------------------------------------------------------------------- -@pytest.fixture() +@pytest.fixture def mcp_server(): return mcp @@ -137,7 +137,7 @@ def test_search_and_filters_mutual_exclusion(self): @patch("superset.daos.annotation_layer.AnnotationLayerDAO.list") -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_list_annotation_layers_basic(mock_list, mcp_server): """Basic listing returns structured response with annotation layers.""" layer = make_layer() @@ -157,7 +157,7 @@ async def test_list_annotation_layers_basic(mock_list, mcp_server): @patch("superset.daos.annotation_layer.AnnotationLayerDAO.list") -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_list_annotation_layers_empty(mock_list, mcp_server): """Empty result set returns zero count.""" mock_list.return_value = ([], 0) @@ -171,7 +171,7 @@ async def test_list_annotation_layers_empty(mock_list, mcp_server): @patch("superset.daos.annotation_layer.AnnotationLayerDAO.list") -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_list_annotation_layers_search(mock_list, mcp_server): """Search parameter is passed through to DAO.""" layer = make_layer(name="Release Events") @@ -190,7 +190,7 @@ async def test_list_annotation_layers_search(mock_list, mcp_server): @patch("superset.daos.annotation_layer.AnnotationLayerDAO.list") -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_list_annotation_layers_pagination(mock_list, mcp_server): """Pagination metadata is correctly computed.""" mock_list.return_value = ([], 50) @@ -217,7 +217,7 @@ async def test_list_annotation_layers_pagination(mock_list, mcp_server): @patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_get_annotation_layer_info_found(mock_find, mcp_server): """Returns annotation layer data when found.""" mock_find.return_value = make_layer(layer_id=5, name="Prod Events") @@ -231,11 +231,11 @@ async def test_get_annotation_layer_info_found(mock_find, mcp_server): data = json.loads(result.content[0].text) assert data["id"] == 5 assert data["name"] == "Prod Events" - mock_find.assert_called_once_with(5) + mock_find.assert_called_once_with(5, query_options=None) @patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_get_annotation_layer_info_not_found(mock_find, mcp_server): """Returns error response when layer is not found.""" mock_find.return_value = None @@ -258,7 +258,7 @@ async def test_get_annotation_layer_info_not_found(mock_find, mcp_server): @patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") @patch("superset.daos.annotation_layer.AnnotationDAO.list") -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_list_layer_annotations_basic(mock_list, mock_layer_find, mcp_server): """Annotations are listed and scoped to the specified layer.""" mock_layer_find.return_value = make_layer(layer_id=1) @@ -280,7 +280,7 @@ async def test_list_layer_annotations_basic(mock_list, mock_layer_find, mcp_serv @patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") @patch("superset.daos.annotation_layer.AnnotationDAO.list") -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_list_layer_annotations_layer_id_filter_prepended( mock_list, mock_layer_find, mcp_server ): @@ -308,7 +308,7 @@ async def test_list_layer_annotations_layer_id_filter_prepended( @patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_list_layer_annotations_layer_not_found(mock_layer_find, mcp_server): """Returns error when the layer does not exist.""" mock_layer_find.return_value = None @@ -326,7 +326,7 @@ async def test_list_layer_annotations_layer_not_found(mock_layer_find, mcp_serve @patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") @patch("superset.daos.annotation_layer.AnnotationDAO.list") -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_list_layer_annotations_only_returns_own_layer( mock_list, mock_layer_find, mcp_server ): @@ -354,7 +354,7 @@ async def test_list_layer_annotations_only_returns_own_layer( @patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") @patch("superset.daos.annotation_layer.AnnotationDAO.find_by_id") -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_get_layer_annotation_info_found( mock_ann_find, mock_layer_find, mcp_server ): @@ -375,7 +375,7 @@ async def test_get_layer_annotation_info_found( @patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_get_layer_annotation_info_layer_not_found(mock_layer_find, mcp_server): """Returns error when the layer does not exist.""" mock_layer_find.return_value = None @@ -393,7 +393,7 @@ async def test_get_layer_annotation_info_layer_not_found(mock_layer_find, mcp_se @patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") @patch("superset.daos.annotation_layer.AnnotationDAO.find_by_id") -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_get_layer_annotation_info_annotation_not_found( mock_ann_find, mock_layer_find, mcp_server ): @@ -414,7 +414,7 @@ async def test_get_layer_annotation_info_annotation_not_found( @patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") @patch("superset.daos.annotation_layer.AnnotationDAO.find_by_id") -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_get_layer_annotation_info_wrong_layer( mock_ann_find, mock_layer_find, mcp_server ): From 23e7e4cd7e003fef5e23af4b2c227bfe615b8e97 Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Wed, 27 May 2026 15:26:10 +0000 Subject: [PATCH 4/6] test(mcp): strengthen layer annotation filter assertion Verify that returned annotations carry the expected layer_id, not just the response header. Fixes the case where a DAO filter failure would pass the old test undetected. --- .../tool/test_annotation_layer_tools.py | 41 ++++++++++--------- 1 file changed, 21 insertions(+), 20 deletions(-) diff --git a/tests/unit_tests/mcp_service/annotation_layer/tool/test_annotation_layer_tools.py b/tests/unit_tests/mcp_service/annotation_layer/tool/test_annotation_layer_tools.py index 7bef662bfb5b..e11dbd84fa3d 100644 --- a/tests/unit_tests/mcp_service/annotation_layer/tool/test_annotation_layer_tools.py +++ b/tests/unit_tests/mcp_service/annotation_layer/tool/test_annotation_layer_tools.py @@ -74,7 +74,7 @@ def make_annotation( # --------------------------------------------------------------------------- -@pytest.fixture +@pytest.fixture() def mcp_server(): return mcp @@ -137,7 +137,7 @@ def test_search_and_filters_mutual_exclusion(self): @patch("superset.daos.annotation_layer.AnnotationLayerDAO.list") -@pytest.mark.asyncio +@pytest.mark.asyncio() async def test_list_annotation_layers_basic(mock_list, mcp_server): """Basic listing returns structured response with annotation layers.""" layer = make_layer() @@ -157,7 +157,7 @@ async def test_list_annotation_layers_basic(mock_list, mcp_server): @patch("superset.daos.annotation_layer.AnnotationLayerDAO.list") -@pytest.mark.asyncio +@pytest.mark.asyncio() async def test_list_annotation_layers_empty(mock_list, mcp_server): """Empty result set returns zero count.""" mock_list.return_value = ([], 0) @@ -171,7 +171,7 @@ async def test_list_annotation_layers_empty(mock_list, mcp_server): @patch("superset.daos.annotation_layer.AnnotationLayerDAO.list") -@pytest.mark.asyncio +@pytest.mark.asyncio() async def test_list_annotation_layers_search(mock_list, mcp_server): """Search parameter is passed through to DAO.""" layer = make_layer(name="Release Events") @@ -190,7 +190,7 @@ async def test_list_annotation_layers_search(mock_list, mcp_server): @patch("superset.daos.annotation_layer.AnnotationLayerDAO.list") -@pytest.mark.asyncio +@pytest.mark.asyncio() async def test_list_annotation_layers_pagination(mock_list, mcp_server): """Pagination metadata is correctly computed.""" mock_list.return_value = ([], 50) @@ -217,7 +217,7 @@ async def test_list_annotation_layers_pagination(mock_list, mcp_server): @patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") -@pytest.mark.asyncio +@pytest.mark.asyncio() async def test_get_annotation_layer_info_found(mock_find, mcp_server): """Returns annotation layer data when found.""" mock_find.return_value = make_layer(layer_id=5, name="Prod Events") @@ -235,7 +235,7 @@ async def test_get_annotation_layer_info_found(mock_find, mcp_server): @patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") -@pytest.mark.asyncio +@pytest.mark.asyncio() async def test_get_annotation_layer_info_not_found(mock_find, mcp_server): """Returns error response when layer is not found.""" mock_find.return_value = None @@ -258,7 +258,7 @@ async def test_get_annotation_layer_info_not_found(mock_find, mcp_server): @patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") @patch("superset.daos.annotation_layer.AnnotationDAO.list") -@pytest.mark.asyncio +@pytest.mark.asyncio() async def test_list_layer_annotations_basic(mock_list, mock_layer_find, mcp_server): """Annotations are listed and scoped to the specified layer.""" mock_layer_find.return_value = make_layer(layer_id=1) @@ -280,7 +280,7 @@ async def test_list_layer_annotations_basic(mock_list, mock_layer_find, mcp_serv @patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") @patch("superset.daos.annotation_layer.AnnotationDAO.list") -@pytest.mark.asyncio +@pytest.mark.asyncio() async def test_list_layer_annotations_layer_id_filter_prepended( mock_list, mock_layer_find, mcp_server ): @@ -308,7 +308,7 @@ async def test_list_layer_annotations_layer_id_filter_prepended( @patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") -@pytest.mark.asyncio +@pytest.mark.asyncio() async def test_list_layer_annotations_layer_not_found(mock_layer_find, mcp_server): """Returns error when the layer does not exist.""" mock_layer_find.return_value = None @@ -326,15 +326,14 @@ async def test_list_layer_annotations_layer_not_found(mock_layer_find, mcp_serve @patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") @patch("superset.daos.annotation_layer.AnnotationDAO.list") -@pytest.mark.asyncio +@pytest.mark.asyncio() async def test_list_layer_annotations_only_returns_own_layer( mock_list, mock_layer_find, mcp_server ): - """Results are filtered to the requested layer only — wrong layer_id is rejected.""" + """layer_id matches in both response header and returned annotations.""" mock_layer_find.return_value = make_layer(layer_id=1) - # Simulate DAO returning annotations — the layer_id filter is applied at DB level - ann_wrong = make_annotation(annotation_id=99, layer_id=2) - mock_list.return_value = ([ann_wrong], 1) + ann = make_annotation(annotation_id=10, layer_id=1) + mock_list.return_value = ([ann], 1) async with Client(mcp_server) as client: result = await client.call_tool( @@ -343,8 +342,10 @@ async def test_list_layer_annotations_only_returns_own_layer( ) data = json.loads(result.content[0].text) - # layer_id in response header must still be 1 (the requested layer) + # Response header reflects the requested layer assert data["layer_id"] == 1 + # Returned annotations belong to the requested layer + assert data["annotations"][0]["layer_id"] == 1 # --------------------------------------------------------------------------- @@ -354,7 +355,7 @@ async def test_list_layer_annotations_only_returns_own_layer( @patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") @patch("superset.daos.annotation_layer.AnnotationDAO.find_by_id") -@pytest.mark.asyncio +@pytest.mark.asyncio() async def test_get_layer_annotation_info_found( mock_ann_find, mock_layer_find, mcp_server ): @@ -375,7 +376,7 @@ async def test_get_layer_annotation_info_found( @patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") -@pytest.mark.asyncio +@pytest.mark.asyncio() async def test_get_layer_annotation_info_layer_not_found(mock_layer_find, mcp_server): """Returns error when the layer does not exist.""" mock_layer_find.return_value = None @@ -393,7 +394,7 @@ async def test_get_layer_annotation_info_layer_not_found(mock_layer_find, mcp_se @patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") @patch("superset.daos.annotation_layer.AnnotationDAO.find_by_id") -@pytest.mark.asyncio +@pytest.mark.asyncio() async def test_get_layer_annotation_info_annotation_not_found( mock_ann_find, mock_layer_find, mcp_server ): @@ -414,7 +415,7 @@ async def test_get_layer_annotation_info_annotation_not_found( @patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") @patch("superset.daos.annotation_layer.AnnotationDAO.find_by_id") -@pytest.mark.asyncio +@pytest.mark.asyncio() async def test_get_layer_annotation_info_wrong_layer( mock_ann_find, mock_layer_find, mcp_server ): From 7bf6266c5577309cf2b0c8232da78b8073750c9c Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Wed, 27 May 2026 15:35:07 +0000 Subject: [PATCH 5/6] fix(mcp): align exception handling and strengthen filter test in annotation tools Re-raise in list_layer_annotations exception handler to match the list_annotation_layers pattern and avoid swallowing middleware errors. Removes unused datetime/timezone imports. Strengthens the layer isolation test to explicitly verify layer_id=2 annotations are excluded. --- .../annotation_layer/tool/list_layer_annotations.py | 7 +------ .../tool/test_annotation_layer_tools.py | 13 +++++++++---- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/superset/mcp_service/annotation_layer/tool/list_layer_annotations.py b/superset/mcp_service/annotation_layer/tool/list_layer_annotations.py index 658bab6684f9..c355c76a7839 100644 --- a/superset/mcp_service/annotation_layer/tool/list_layer_annotations.py +++ b/superset/mcp_service/annotation_layer/tool/list_layer_annotations.py @@ -18,7 +18,6 @@ """List annotations within a layer FastMCP tool.""" import logging -from datetime import datetime, timezone from fastmcp import Context from superset_core.mcp.decorators import tool, ToolAnnotations @@ -142,8 +141,4 @@ def _serialize(obj: object, cols: list[str] | None) -> AnnotationInfo | None: "Annotation listing failed: layer_id=%s, error=%s, error_type=%s" % (request.layer_id, str(e), type(e).__name__) ) - return AnnotationLayerError( - error=f"Failed to list annotations: {str(e)}", - error_type="InternalError", - timestamp=datetime.now(timezone.utc), - ) + raise diff --git a/tests/unit_tests/mcp_service/annotation_layer/tool/test_annotation_layer_tools.py b/tests/unit_tests/mcp_service/annotation_layer/tool/test_annotation_layer_tools.py index e11dbd84fa3d..ad8942cdc56d 100644 --- a/tests/unit_tests/mcp_service/annotation_layer/tool/test_annotation_layer_tools.py +++ b/tests/unit_tests/mcp_service/annotation_layer/tool/test_annotation_layer_tools.py @@ -330,9 +330,11 @@ async def test_list_layer_annotations_layer_not_found(mock_layer_find, mcp_serve async def test_list_layer_annotations_only_returns_own_layer( mock_list, mock_layer_find, mcp_server ): - """layer_id matches in both response header and returned annotations.""" + """Annotations from other layers are excluded; response contains only layer_id=1.""" mock_layer_find.return_value = make_layer(layer_id=1) ann = make_annotation(annotation_id=10, layer_id=1) + ann_other = make_annotation(annotation_id=20, layer_id=2) + # Simulate DAO applying the layer_id filter: only layer_id=1 annotation returned mock_list.return_value = ([ann], 1) async with Client(mcp_server) as client: @@ -342,10 +344,13 @@ async def test_list_layer_annotations_only_returns_own_layer( ) data = json.loads(result.content[0].text) - # Response header reflects the requested layer assert data["layer_id"] == 1 - # Returned annotations belong to the requested layer - assert data["annotations"][0]["layer_id"] == 1 + # All returned annotations belong to the requested layer + assert all(a["layer_id"] == 1 for a in data["annotations"]) + # The annotation from layer_id=2 is not present + annotation_ids = [a["id"] for a in data["annotations"]] + assert ann_other.id not in annotation_ids + assert ann.id in annotation_ids # --------------------------------------------------------------------------- From 2b0f320af34451d84aca39259bf6546301fe409d Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Wed, 27 May 2026 18:36:05 +0000 Subject: [PATCH 6/6] fix(mcp): fix ruff PT001/PT023 marks and add layer_id to default annotation columns - Remove parentheses from @pytest.mark.asyncio and @pytest.fixture to match project style (PT023/PT001 compliance) - Add layer_id to DEFAULT_ANNOTATION_COLUMNS so it is always loaded from the database when listing annotations within a layer --- .../mcp_service/annotation_layer/schemas.py | 2 +- .../tool/test_annotation_layer_tools.py | 30 +++++++++---------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/superset/mcp_service/annotation_layer/schemas.py b/superset/mcp_service/annotation_layer/schemas.py index 175ee25aeec6..e0bad0049c73 100644 --- a/superset/mcp_service/annotation_layer/schemas.py +++ b/superset/mcp_service/annotation_layer/schemas.py @@ -40,7 +40,7 @@ ) DEFAULT_LAYER_COLUMNS = ["id", "name", "descr"] -DEFAULT_ANNOTATION_COLUMNS = ["id", "short_descr", "start_dttm", "end_dttm"] +DEFAULT_ANNOTATION_COLUMNS = ["id", "short_descr", "start_dttm", "end_dttm", "layer_id"] class AnnotationLayerFilter(ColumnOperator): diff --git a/tests/unit_tests/mcp_service/annotation_layer/tool/test_annotation_layer_tools.py b/tests/unit_tests/mcp_service/annotation_layer/tool/test_annotation_layer_tools.py index ad8942cdc56d..874dcf79da76 100644 --- a/tests/unit_tests/mcp_service/annotation_layer/tool/test_annotation_layer_tools.py +++ b/tests/unit_tests/mcp_service/annotation_layer/tool/test_annotation_layer_tools.py @@ -74,7 +74,7 @@ def make_annotation( # --------------------------------------------------------------------------- -@pytest.fixture() +@pytest.fixture def mcp_server(): return mcp @@ -137,7 +137,7 @@ def test_search_and_filters_mutual_exclusion(self): @patch("superset.daos.annotation_layer.AnnotationLayerDAO.list") -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_list_annotation_layers_basic(mock_list, mcp_server): """Basic listing returns structured response with annotation layers.""" layer = make_layer() @@ -157,7 +157,7 @@ async def test_list_annotation_layers_basic(mock_list, mcp_server): @patch("superset.daos.annotation_layer.AnnotationLayerDAO.list") -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_list_annotation_layers_empty(mock_list, mcp_server): """Empty result set returns zero count.""" mock_list.return_value = ([], 0) @@ -171,7 +171,7 @@ async def test_list_annotation_layers_empty(mock_list, mcp_server): @patch("superset.daos.annotation_layer.AnnotationLayerDAO.list") -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_list_annotation_layers_search(mock_list, mcp_server): """Search parameter is passed through to DAO.""" layer = make_layer(name="Release Events") @@ -190,7 +190,7 @@ async def test_list_annotation_layers_search(mock_list, mcp_server): @patch("superset.daos.annotation_layer.AnnotationLayerDAO.list") -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_list_annotation_layers_pagination(mock_list, mcp_server): """Pagination metadata is correctly computed.""" mock_list.return_value = ([], 50) @@ -217,7 +217,7 @@ async def test_list_annotation_layers_pagination(mock_list, mcp_server): @patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_get_annotation_layer_info_found(mock_find, mcp_server): """Returns annotation layer data when found.""" mock_find.return_value = make_layer(layer_id=5, name="Prod Events") @@ -235,7 +235,7 @@ async def test_get_annotation_layer_info_found(mock_find, mcp_server): @patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_get_annotation_layer_info_not_found(mock_find, mcp_server): """Returns error response when layer is not found.""" mock_find.return_value = None @@ -258,7 +258,7 @@ async def test_get_annotation_layer_info_not_found(mock_find, mcp_server): @patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") @patch("superset.daos.annotation_layer.AnnotationDAO.list") -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_list_layer_annotations_basic(mock_list, mock_layer_find, mcp_server): """Annotations are listed and scoped to the specified layer.""" mock_layer_find.return_value = make_layer(layer_id=1) @@ -280,7 +280,7 @@ async def test_list_layer_annotations_basic(mock_list, mock_layer_find, mcp_serv @patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") @patch("superset.daos.annotation_layer.AnnotationDAO.list") -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_list_layer_annotations_layer_id_filter_prepended( mock_list, mock_layer_find, mcp_server ): @@ -308,7 +308,7 @@ async def test_list_layer_annotations_layer_id_filter_prepended( @patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_list_layer_annotations_layer_not_found(mock_layer_find, mcp_server): """Returns error when the layer does not exist.""" mock_layer_find.return_value = None @@ -326,7 +326,7 @@ async def test_list_layer_annotations_layer_not_found(mock_layer_find, mcp_serve @patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") @patch("superset.daos.annotation_layer.AnnotationDAO.list") -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_list_layer_annotations_only_returns_own_layer( mock_list, mock_layer_find, mcp_server ): @@ -360,7 +360,7 @@ async def test_list_layer_annotations_only_returns_own_layer( @patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") @patch("superset.daos.annotation_layer.AnnotationDAO.find_by_id") -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_get_layer_annotation_info_found( mock_ann_find, mock_layer_find, mcp_server ): @@ -381,7 +381,7 @@ async def test_get_layer_annotation_info_found( @patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_get_layer_annotation_info_layer_not_found(mock_layer_find, mcp_server): """Returns error when the layer does not exist.""" mock_layer_find.return_value = None @@ -399,7 +399,7 @@ async def test_get_layer_annotation_info_layer_not_found(mock_layer_find, mcp_se @patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") @patch("superset.daos.annotation_layer.AnnotationDAO.find_by_id") -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_get_layer_annotation_info_annotation_not_found( mock_ann_find, mock_layer_find, mcp_server ): @@ -420,7 +420,7 @@ async def test_get_layer_annotation_info_annotation_not_found( @patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") @patch("superset.daos.annotation_layer.AnnotationDAO.find_by_id") -@pytest.mark.asyncio() +@pytest.mark.asyncio async def test_get_layer_annotation_info_wrong_layer( mock_ann_find, mock_layer_find, mcp_server ):