diff --git a/superset/daos/tag.py b/superset/daos/tag.py index c063657ea09af..9ec5755bebc81 100644 --- a/superset/daos/tag.py +++ b/superset/daos/tag.py @@ -394,7 +394,11 @@ def create_tag_relationship( updated_tagged_objects = { (to_object_type(obj[0]), obj[1]) for obj in objects_to_tag } - tagged_objects_to_delete = current_tagged_objects - updated_tagged_objects + + if not objects_to_tag: + tagged_objects_to_delete = current_tagged_objects + else: + tagged_objects_to_delete = current_tagged_objects - updated_tagged_objects for object_type, object_id in updated_tagged_objects: # create rows for new objects, and skip tags that already exist diff --git a/superset/tags/commands/create.py b/superset/tags/commands/create.py index e8311ad520be4..94c262dd26dd2 100644 --- a/superset/tags/commands/create.py +++ b/superset/tags/commands/create.py @@ -67,25 +67,23 @@ def validate(self) -> None: class CreateCustomTagWithRelationshipsCommand(CreateMixin, BaseCommand): def __init__(self, data: dict[str, Any], bulk_create: bool = False): - self._tag = data["name"] - self._objects_to_tag = data.get("objects_to_tag") - self._description = data.get("description") + self._properties = data.copy() self._bulk_create = bulk_create def run(self) -> None: self.validate() try: - tag = TagDAO.get_by_name(self._tag.strip(), TagTypes.custom) - if self._objects_to_tag: - TagDAO.create_tag_relationship( - objects_to_tag=self._objects_to_tag, - tag=tag, - bulk_create=self._bulk_create, - ) + tag_name = self._properties["name"] + tag = TagDAO.get_by_name(tag_name.strip(), TagTypes.custom) + TagDAO.create_tag_relationship( + objects_to_tag=self._properties.get("objects_to_tag", []), + tag=tag, + bulk_create=self._bulk_create, + ) - if self._description: - tag.description = self._description + if description := self._properties.get("description"): + tag.description = description db.session.commit() @@ -96,13 +94,13 @@ def run(self) -> None: def validate(self) -> None: exceptions = [] # Validate object_id - if self._objects_to_tag: - if any(obj_id == 0 for obj_type, obj_id in self._objects_to_tag): + if objects_to_tag := self._properties.get("objects_to_tag", []): + if any(obj_id == 0 for obj_type, obj_id in objects_to_tag): exceptions.append(TagInvalidError()) # Validate object type skipped_tagged_objects: list[tuple[str, int]] = [] - for obj_type, obj_id in self._objects_to_tag: + for obj_type, obj_id in objects_to_tag: skipped_tagged_objects = [] object_type = to_object_type(obj_type) @@ -117,7 +115,7 @@ def validate(self) -> None: # skip the object if the user doesn't have access skipped_tagged_objects.append((obj_type, obj_id)) - self._objects_to_tag = set(self._objects_to_tag) - set( + self._properties["objects_to_tag"] = set(objects_to_tag) - set( skipped_tagged_objects ) diff --git a/superset/tags/commands/update.py b/superset/tags/commands/update.py index a13e4e8e7bbb0..597b1980f65c2 100644 --- a/superset/tags/commands/update.py +++ b/superset/tags/commands/update.py @@ -38,12 +38,10 @@ def __init__(self, model_id: int, data: dict[str, Any]): def run(self) -> Model: self.validate() if self._model: - if self._properties.get("objects_to_tag"): - # todo(hugh): can this manage duplication - TagDAO.create_tag_relationship( - objects_to_tag=self._properties["objects_to_tag"], - tag=self._model, - ) + TagDAO.create_tag_relationship( + objects_to_tag=self._properties["objects_to_tag"], + tag=self._model, + ) if description := self._properties.get("description"): self._model.description = description if tag_name := self._properties.get("name"): diff --git a/superset/tags/schemas.py b/superset/tags/schemas.py index 571a2a03c9e74..a91a5709c16bf 100644 --- a/superset/tags/schemas.py +++ b/superset/tags/schemas.py @@ -58,7 +58,7 @@ class TagObjectSchema(Schema): name = fields.String() description = fields.String(required=False, allow_none=True) objects_to_tag = fields.List( - fields.Tuple((fields.String(), fields.Int())), required=False + fields.Tuple((fields.String(), fields.Int())), )