diff --git a/superset/daos/tag.py b/superset/daos/tag.py index e4aa89181644d..46a1d2538f16a 100644 --- a/superset/daos/tag.py +++ b/superset/daos/tag.py @@ -51,14 +51,29 @@ def create_custom_tagged_objects( object_type: ObjectType, object_id: int, tag_names: list[str] ) -> None: tagged_objects = [] - for name in tag_names: + + # striping and de-dupping + clean_tag_names: set[str] = {tag.strip() for tag in tag_names} + + for name in clean_tag_names: type_ = TagType.custom - tag_name = name.strip() - tag = TagDAO.get_by_name(tag_name, type_) + tag = TagDAO.get_by_name(name, type_) tagged_objects.append( TaggedObject(object_id=object_id, object_type=object_type, tag=tag) ) + # Check if the association already exists + existing_tagged_object = ( + db.session.query(TaggedObject) + .filter_by(object_id=object_id, object_type=object_type, tag=tag) + .first() + ) + + if not existing_tagged_object: + tagged_objects.append( + TaggedObject(object_id=object_id, object_type=object_type, tag=tag) + ) + db.session.add_all(tagged_objects) db.session.commit() diff --git a/superset/migrations/__init__.py b/superset/migrations/__init__.py index 13a83393a9124..b083f44bb43e7 100644 --- a/superset/migrations/__init__.py +++ b/superset/migrations/__init__.py @@ -14,3 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + +import os +import sys + +# hack to be able to import / reuse migration_utils.py in revisions +module_dir = os.path.dirname(os.path.realpath(__file__)) +sys.path.append(module_dir) diff --git a/superset/migrations/migration_utils.py b/superset/migrations/migration_utils.py new file mode 100644 index 0000000000000..c754669a1af69 --- /dev/null +++ b/superset/migrations/migration_utils.py @@ -0,0 +1,46 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from alembic.operations import BatchOperations, Operations + +naming_convention = { + "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", + "uq": "uq_%(table_name)s_%(column_0_name)s", +} + + +def create_unique_constraint( + op: Operations, index_id: str, table_name: str, uix_columns: list[str] +) -> None: + with op.batch_alter_table( + table_name, naming_convention=naming_convention + ) as batch_op: + batch_op.create_unique_constraint(index_id, uix_columns) + + +def drop_unique_constraint(op: Operations, index_id: str, table_name: str) -> None: + dialect = op.get_bind().dialect.name + + with op.batch_alter_table( + table_name, naming_convention=naming_convention + ) as batch_op: + if dialect == "mysql": + # MySQL requires specifying the type of constraint + batch_op.drop_constraint(index_id, type_="unique") + else: + # For other databases, a standard drop_constraint call is sufficient + batch_op.drop_constraint(index_id) diff --git a/superset/migrations/versions/2024-01-17_13-09_96164e3017c6_tagged_object_unique_constraint.py b/superset/migrations/versions/2024-01-17_13-09_96164e3017c6_tagged_object_unique_constraint.py new file mode 100644 index 0000000000000..0b67ad5024f75 --- /dev/null +++ b/superset/migrations/versions/2024-01-17_13-09_96164e3017c6_tagged_object_unique_constraint.py @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import enum + +import migration_utils as utils +import sqlalchemy as sa +from alembic import op +from sqlalchemy import Column, Enum, Integer, MetaData, Table +from sqlalchemy.sql import and_, func, select + +# revision identifiers, used by Alembic. +revision = "96164e3017c6" +down_revision = "59a1450b3c10" + + +class ObjectType(enum.Enum): + # pylint: disable=invalid-name + query = 1 + chart = 2 + dashboard = 3 + dataset = 4 + + +# Define the tagged_object table structure +metadata = MetaData() +tagged_object_table = Table( + "tagged_object", + metadata, + Column("id", Integer, primary_key=True), + Column("tag_id", Integer), + Column("object_id", Integer), + Column("object_type", Enum(ObjectType)), # Replace ObjectType with your Enum +) + +index_id = "uix_tagged_object" +table_name = "tagged_object" +uix_columns = ["tag_id", "object_id", "object_type"] + + +def upgrade(): + bind = op.get_bind() # Get the database connection bind + + # Reflect the current database state to get existing tables + metadata.reflect(bind=bind) + + # Delete duplicates if any + min_id_subquery = ( + select( + [ + func.min(tagged_object_table.c.id).label("min_id"), + tagged_object_table.c.tag_id, + tagged_object_table.c.object_id, + tagged_object_table.c.object_type, + ] + ) + .group_by( + tagged_object_table.c.tag_id, + tagged_object_table.c.object_id, + tagged_object_table.c.object_type, + ) + .alias("min_ids") + ) + + delete_query = tagged_object_table.delete().where( + tagged_object_table.c.id.notin_(select([min_id_subquery.c.min_id])) + ) + + bind.execute(delete_query) + + # Create unique constraint + utils.create_unique_constraint(op, index_id, table_name, uix_columns) + + +def downgrade(): + utils.drop_unique_constraint(op, index_id, table_name) diff --git a/superset/migrations/versions/2024-01-18_12-12_15a2c68a2e6b_merging_two_heads.py b/superset/migrations/versions/2024-01-18_12-12_15a2c68a2e6b_merging_two_heads.py new file mode 100644 index 0000000000000..7904d9298df47 --- /dev/null +++ b/superset/migrations/versions/2024-01-18_12-12_15a2c68a2e6b_merging_two_heads.py @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""merging two heads + +Revision ID: 15a2c68a2e6b +Revises: ('96164e3017c6', 'a32e0c4d8646') +Create Date: 2024-01-18 12:12:52.174742 + +""" + +# revision identifiers, used by Alembic. +revision = "15a2c68a2e6b" +down_revision = ("96164e3017c6", "a32e0c4d8646") + +import sqlalchemy as sa +from alembic import op + + +def upgrade(): + pass + + +def downgrade(): + pass diff --git a/superset/migrations/versions/2024-01-19_08-42_1cf8e4344e2b_merging.py b/superset/migrations/versions/2024-01-19_08-42_1cf8e4344e2b_merging.py new file mode 100644 index 0000000000000..9ac2a9b24ff5d --- /dev/null +++ b/superset/migrations/versions/2024-01-19_08-42_1cf8e4344e2b_merging.py @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""merging + +Revision ID: 1cf8e4344e2b +Revises: ('e863403c0c50', '15a2c68a2e6b') +Create Date: 2024-01-19 08:42:37.694192 + +""" + +# revision identifiers, used by Alembic. +revision = "1cf8e4344e2b" +down_revision = ("e863403c0c50", "15a2c68a2e6b") + +import sqlalchemy as sa +from alembic import op + + +def upgrade(): + pass + + +def downgrade(): + pass diff --git a/superset/tags/models.py b/superset/tags/models.py index bae4417507bd4..1e8ca7de1a332 100644 --- a/superset/tags/models.py +++ b/superset/tags/models.py @@ -21,10 +21,21 @@ from flask import escape from flask_appbuilder import Model -from sqlalchemy import Column, Enum, ForeignKey, Integer, orm, String, Table, Text +from sqlalchemy import ( + Column, + Enum, + exists, + ForeignKey, + Integer, + orm, + String, + Table, + Text, +) from sqlalchemy.engine.base import Connection from sqlalchemy.orm import relationship, sessionmaker from sqlalchemy.orm.mapper import Mapper +from sqlalchemy.schema import UniqueConstraint from superset import security_manager from superset.models.helpers import AuditMixinNullable @@ -110,6 +121,14 @@ class TaggedObject(Model, AuditMixinNullable): object_type = Column(Enum(ObjectType)) tag = relationship("Tag", back_populates="objects", overlaps="tags") + __table_args__ = ( + UniqueConstraint( + "tag_id", "object_id", "object_type", name="uix_tagged_object" + ), + ) + + def __str__(self) -> str: + return f"" def get_tag(name: str, session: orm.Session, type_: TagType) -> Tag: @@ -138,7 +157,7 @@ def get_object_type(class_name: str) -> ObjectType: class ObjectUpdater: - object_type: str | None = None + object_type: str = "default" @classmethod def get_owners_ids( @@ -146,6 +165,19 @@ def get_owners_ids( ) -> list[int]: raise NotImplementedError("Subclass should implement `get_owners_ids`") + @classmethod + def get_owner_tag_ids( + cls, + session: orm.Session, + target: Dashboard | FavStar | Slice | Query | SqlaTable, + ) -> set[int]: + tag_ids = set() + for owner_id in cls.get_owners_ids(target): + name = f"owner:{owner_id}" + tag = get_tag(name, session, TagType.owner) + tag_ids.add(tag.id) + return tag_ids + @classmethod def _add_owners( cls, @@ -153,10 +185,28 @@ def _add_owners( target: Dashboard | FavStar | Slice | Query | SqlaTable, ) -> None: for owner_id in cls.get_owners_ids(target): - name = f"owner:{owner_id}" + name: str = f"owner:{owner_id}" tag = get_tag(name, session, TagType.owner) + cls.add_tag_object_if_not_tagged( + session, tag_id=tag.id, object_id=target.id, object_type=cls.object_type + ) + + @classmethod + def add_tag_object_if_not_tagged( + cls, session: orm.Session, tag_id: int, object_id: int, object_type: str + ) -> None: + # Check if the object is already tagged + exists_query = exists().where( + TaggedObject.tag_id == tag_id, + TaggedObject.object_id == object_id, + TaggedObject.object_type == object_type, + ) + already_tagged = session.query(exists_query).scalar() + + # Add TaggedObject to the session if it isn't already tagged + if not already_tagged: tagged_object = TaggedObject( - tag_id=tag.id, object_id=target.id, object_type=cls.object_type + tag_id=tag_id, object_id=object_id, object_type=object_type ) session.add(tagged_object) @@ -173,10 +223,9 @@ def after_insert( # add `type:` tags tag = get_tag(f"type:{cls.object_type}", session, TagType.type) - tagged_object = TaggedObject( - tag_id=tag.id, object_id=target.id, object_type=cls.object_type + cls.add_tag_object_if_not_tagged( + session, tag_id=tag.id, object_id=target.id, object_type=cls.object_type ) - session.add(tagged_object) session.commit() @classmethod @@ -187,23 +236,35 @@ def after_update( target: Dashboard | FavStar | Slice | Query | SqlaTable, ) -> None: with Session(bind=connection) as session: - # delete current `owner:` tags - query = ( - session.query(TaggedObject.id) + # Fetch current owner tags + existing_tags = ( + session.query(TaggedObject) .join(Tag) .filter( TaggedObject.object_type == cls.object_type, TaggedObject.object_id == target.id, Tag.type == TagType.owner, ) + .all() ) - ids = [row[0] for row in query] - session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete( - synchronize_session=False - ) + existing_owner_tag_ids = {tag.tag_id for tag in existing_tags} - # add `owner:` tags - cls._add_owners(session, target) + # Determine new owner IDs + new_owner_tag_ids = cls.get_owner_tag_ids(session, target) + + # Add missing tags + for owner_tag_id in new_owner_tag_ids - existing_owner_tag_ids: + tagged_object = TaggedObject( + tag_id=owner_tag_id, + object_id=target.id, + object_type=cls.object_type, + ) + session.add(tagged_object) + + # Remove unnecessary tags + for tag in existing_tags: + if tag.tag_id not in new_owner_tag_ids: + session.delete(tag) session.commit() @classmethod diff --git a/tests/integration_tests/tags/api_tests.py b/tests/integration_tests/tags/api_tests.py index 863288a3e73ec..d79261c2a3022 100644 --- a/tests/integration_tests/tags/api_tests.py +++ b/tests/integration_tests/tags/api_tests.py @@ -577,15 +577,25 @@ def test_post_bulk_tag(self): result = TagDAO.get_tagged_objects_for_tags(tags, ["chart"]) assert len(result) == 1 - tagged_objects = db.session.query(TaggedObject).filter( - TaggedObject.object_id == dashboard.id, - TaggedObject.object_type == ObjectType.dashboard, + tagged_objects = ( + db.session.query(TaggedObject) + .join(Tag) + .filter( + TaggedObject.object_id == dashboard.id, + TaggedObject.object_type == ObjectType.dashboard, + Tag.type == TagType.custom, + ) ) assert tagged_objects.count() == 2 - tagged_objects = db.session.query(TaggedObject).filter( - TaggedObject.object_id == chart.id, - TaggedObject.object_type == ObjectType.chart, + tagged_objects = ( + db.session.query(TaggedObject) + .join(Tag) + .filter( + TaggedObject.object_id == chart.id, + TaggedObject.object_type == ObjectType.chart, + Tag.type == TagType.custom, + ) ) assert tagged_objects.count() == 2 diff --git a/tests/integration_tests/tags/commands_tests.py b/tests/integration_tests/tags/commands_tests.py index 83762f8f6e876..3644c076e6a3d 100644 --- a/tests/integration_tests/tags/commands_tests.py +++ b/tests/integration_tests/tags/commands_tests.py @@ -63,7 +63,7 @@ def test_create_custom_tag_command(self): example_dashboard = ( db.session.query(Dashboard).filter_by(slug="world_health").one() ) - example_tags = ["create custom tag example 1", "create custom tag example 2"] + example_tags = {"create custom tag example 1", "create custom tag example 2"} command = CreateCustomTagCommand( ObjectType.dashboard.value, example_dashboard.id, example_tags ) @@ -78,7 +78,7 @@ def test_create_custom_tag_command(self): ) .all() ) - assert example_tags == [tag.name for tag in created_tags] + assert example_tags == {tag.name for tag in created_tags} # cleanup tags = db.session.query(Tag).filter(Tag.name.in_(example_tags)) @@ -99,7 +99,7 @@ def test_delete_tags_command(self): .filter_by(dashboard_title="World Bank's Data") .one() ) - example_tags = ["create custom tag example 1", "create custom tag example 2"] + example_tags = {"create custom tag example 1", "create custom tag example 2"} command = CreateCustomTagCommand( ObjectType.dashboard.value, example_dashboard.id, example_tags ) @@ -115,7 +115,7 @@ def test_delete_tags_command(self): .order_by(Tag.name) .all() ) - assert example_tags == [tag.name for tag in created_tags] + assert example_tags == {tag.name for tag in created_tags} command = DeleteTagsCommand(example_tags) command.run() @@ -132,7 +132,7 @@ def test_delete_tags_command(self): example_dashboard = ( db.session.query(Dashboard).filter_by(slug="world_health").one() ) - example_tags = ["create custom tag example 1", "create custom tag example 2"] + example_tags = {"create custom tag example 1", "create custom tag example 2"} command = CreateCustomTagCommand( ObjectType.dashboard.value, example_dashboard.id, example_tags ) @@ -152,7 +152,7 @@ def test_delete_tags_command(self): command = DeleteTaggedObjectCommand( object_type=ObjectType.dashboard.value, object_id=example_dashboard.id, - tag=example_tags[0], + tag=list(example_tags)[0], ) command.run() tagged_objects = (