Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions superset/commands/tag/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,35 @@ def validate(self) -> None:
exceptions.append(
TagCreateFailedError(f"invalid object type {self._object_type}")
)

# Validate user has access to the target object
if object_type:
self._validate_object_access(object_type, self._object_id, exceptions)

if exceptions:
raise TagInvalidError(exceptions=exceptions)

def _validate_object_access(
self, object_type: ObjectType, object_id: int, exceptions: list[Any]
) -> None:
"""Validate that the current user has access to the target object."""
try:
target_object = to_object_model(object_type, object_id)
if target_object is None:
exceptions.append(
TagCreateFailedError(f"Access denied for {object_type} {object_id}")
)
return
if hasattr(target_object, "raise_for_access"):
target_object.raise_for_access()
except SupersetSecurityException:
exceptions.append(
TagCreateFailedError(f"Access denied for {object_type} {object_id}")
)
except AttributeError:
# No request context (e.g. background task) — skip access check
pass


class CreateCustomTagWithRelationshipsCommand(CreateMixin, BaseCommand):
def __init__(self, data: dict[str, Any], bulk_create: bool = False):
Expand Down
28 changes: 27 additions & 1 deletion superset/commands/tag/delete.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
import logging
from functools import partial
from typing import Any

from superset.commands.base import BaseCommand
from superset.commands.tag.exceptions import (
Expand All @@ -25,8 +26,9 @@
TagInvalidError,
TagNotFoundError,
)
from superset.commands.tag.utils import to_object_type
from superset.commands.tag.utils import to_object_model, to_object_type
from superset.daos.tag import TagDAO
from superset.exceptions import SupersetSecurityException
from superset.tags.models import ObjectType
from superset.utils.decorators import on_error, transaction
from superset.views.base import DeleteMixin
Expand Down Expand Up @@ -71,6 +73,9 @@ def validate(self) -> None:
)
)
else:
# Validate user has access to the target object
self._validate_object_access(object_type, self._object_id, exceptions)

tagged_object = TagDAO.find_tagged_object(
object_type=object_type, object_id=self._object_id, tag_id=tag.id
)
Expand All @@ -85,6 +90,27 @@ def validate(self) -> None:
if exceptions:
raise TagInvalidError(exceptions=exceptions)

def _validate_object_access(
self, object_type: ObjectType, object_id: int, exceptions: list[Any]
) -> None:
"""Validate that the current user has access to the target object."""
try:
target_object = to_object_model(object_type, object_id)
if target_object is None:
# Object may have been deleted; allow tag cleanup
return
if hasattr(target_object, "raise_for_access"):
target_object.raise_for_access()
except SupersetSecurityException:
exceptions.append(
TaggedObjectDeleteFailedError(
f"Access denied for {object_type} {object_id}"
)
)
except AttributeError:
# No request context (e.g. background task) — skip access check
pass


class DeleteTagsCommand(DeleteMixin, BaseCommand):
def __init__(self, tags: list[str]):
Expand Down
26 changes: 16 additions & 10 deletions superset/commands/tag/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@

from typing import Optional, Union

from superset.daos.chart import ChartDAO
from superset.daos.dashboard import DashboardDAO
from superset.daos.query import SavedQueryDAO
from superset import db
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.models.sql_lab import SavedQuery
Expand All @@ -38,10 +36,18 @@ def to_object_type(object_type: Union[ObjectType, int, str]) -> Optional[ObjectT
def to_object_model(
object_type: ObjectType, object_id: int
) -> Optional[Union[Dashboard, SavedQuery, Slice]]:
if ObjectType.dashboard == object_type:
return DashboardDAO.find_by_id(object_id)
if ObjectType.query == object_type:
return SavedQueryDAO.find_by_id(object_id)
if ObjectType.chart == object_type:
return ChartDAO.find_by_id(object_id)
return None
"""Load a model instance by type and id.

Uses db.session.get() instead of DAO.find_by_id() to avoid DAO base
filters that require request context. Authorization is enforced by the
caller via raise_for_access() on the returned object.
"""
model_map: dict[ObjectType, type] = {
ObjectType.dashboard: Dashboard,
ObjectType.query: SavedQuery,
ObjectType.chart: Slice,
}
model_cls = model_map.get(object_type)
if model_cls is None:
return None
return db.session.get(model_cls, object_id)
4 changes: 2 additions & 2 deletions tests/integration_tests/tags/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,8 +823,8 @@ def test_post_bulk_tag_skipped_tags_perm(self):

assert rv.status_code == 200
result = rv.json["result"]
assert len(result["objects_tagged"]) == 2
assert len(result["objects_skipped"]) == 1
assert len(result["objects_tagged"]) == 1
assert len(result["objects_skipped"]) == 2

def test_create_tag_mysql_compatibility(self) -> None:
"""
Expand Down
Loading