feat(tags): handle case-insensitive tags and remove orphans (#1937)
Merge pull request #1937 from AntaresSimulatorTeam/bugfix/handle-case-insensitive-tags
ANT-940
laurent-laporte-pro authored Feb 23, 2024
2 parents 721d18f + dc252a1 commit 3396f2e
Showing 7 changed files with 203 additions and 32 deletions.
@@ -9,6 +9,7 @@
import itertools
import json
import secrets
import typing as t

import sqlalchemy as sa # type: ignore
from alembic import op
@@ -23,6 +24,22 @@
depends_on = None


def _avoid_duplicates(tags: t.Iterable[str]) -> t.Sequence[str]:
"""Avoid duplicate tags (case insensitive)"""

upper_tags = {tag.upper(): tag for tag in tags}
return list(upper_tags.values())


def _load_patch_obj(patch: t.Optional[str]) -> t.MutableMapping[str, t.Any]:
"""Load the patch object from the `patch` field in the `study_additional_data` table."""

obj: t.MutableMapping[str, t.Any] = json.loads(patch or "{}")
obj["study"] = obj.get("study") or {}
obj["study"]["tags"] = _avoid_duplicates(obj["study"].get("tags") or [])
return obj


def upgrade() -> None:
"""
Populate `tag` and `study_tag` tables from `patch` field in `study_additional_data` table
@@ -39,27 +56,31 @@ def upgrade() -> None:
connexion: Connection = op.get_bind()

# retrieve the tags and the study-tag pairs from the db
study_tags = connexion.execute("SELECT study_id,patch FROM study_additional_data")
tags_by_ids = {}
study_tags = connexion.execute("SELECT study_id, patch FROM study_additional_data")
tags_by_ids: t.MutableMapping[str, t.Set[str]] = {}
for study_id, patch in study_tags:
obj = json.loads(patch or "{}")
study = obj.get("study") or {}
tags = frozenset(study.get("tags") or ())
tags_by_ids[study_id] = tags
obj = _load_patch_obj(patch)
tags_by_ids[study_id] = obj["study"]["tags"]

# delete rows in tables `tag` and `study_tag`
connexion.execute("DELETE FROM study_tag")
connexion.execute("DELETE FROM tag")

# insert the tags in the `tag` table
labels = set(itertools.chain.from_iterable(tags_by_ids.values()))
bulk_tags = [{"label": label, "color": secrets.choice(COLOR_NAMES)} for label in labels]
all_labels = {lbl.upper(): lbl for lbl in itertools.chain.from_iterable(tags_by_ids.values())}
bulk_tags = [{"label": label, "color": secrets.choice(COLOR_NAMES)} for label in all_labels.values()]
if bulk_tags:
sql = sa.text("INSERT INTO tag (label, color) VALUES (:label, :color)")
connexion.execute(sql, *bulk_tags)

# Create relationships between studies and tags in the `study_tag` table
bulk_study_tags = [{"study_id": id_, "tag_label": lbl} for id_, tags in tags_by_ids.items() for lbl in tags]
bulk_study_tags = [
# fmt: off
{"study_id": id_, "tag_label": all_labels[lbl.upper()]}
for id_, tags in tags_by_ids.items()
for lbl in tags
# fmt: on
]
if bulk_study_tags:
sql = sa.text("INSERT INTO study_tag (study_id, tag_label) VALUES (:study_id, :tag_label)")
connexion.execute(sql, *bulk_study_tags)
@@ -78,7 +99,7 @@ def downgrade() -> None:
connexion: Connection = op.get_bind()

# Creating the `tags_by_ids` mapping from data in the `study_tags` table
tags_by_ids = collections.defaultdict(set)
tags_by_ids: t.MutableMapping[str, t.Set[str]] = collections.defaultdict(set)
study_tags = connexion.execute("SELECT study_id, tag_label FROM study_tag")
for study_id, tag_label in study_tags:
tags_by_ids[study_id].add(tag_label)
@@ -87,10 +108,8 @@ def downgrade() -> None:
objects_by_ids = {}
study_tags = connexion.execute("SELECT study_id, patch FROM study_additional_data")
for study_id, patch in study_tags:
obj = json.loads(patch or "{}")
obj["study"] = obj.get("study") or {}
obj["study"]["tags"] = obj["study"].get("tags") or []
obj["study"]["tags"] = sorted(tags_by_ids[study_id] | set(obj["study"]["tags"]))
obj = _load_patch_obj(patch)
obj["study"]["tags"] = _avoid_duplicates(tags_by_ids[study_id] | set(obj["study"]["tags"]))
objects_by_ids[study_id] = obj

# Updating objects in the `study_additional_data` table
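
For reference, the two helpers introduced by this migration are self-contained; below is a minimal sketch of their behavior (the helper bodies mirror the diff above, the sample inputs are hypothetical), showing that for each upper-cased key the last spelling encountered wins:

```python
import json
import typing as t


def _avoid_duplicates(tags: t.Iterable[str]) -> t.Sequence[str]:
    """Drop case-insensitive duplicates; the last spelling of each tag wins."""
    upper_tags = {tag.upper(): tag for tag in tags}
    return list(upper_tags.values())


def _load_patch_obj(patch: t.Optional[str]) -> t.MutableMapping[str, t.Any]:
    """Parse the JSON `patch` column and normalize its `study.tags` list."""
    obj: t.MutableMapping[str, t.Any] = json.loads(patch or "{}")
    obj["study"] = obj.get("study") or {}
    obj["study"]["tags"] = _avoid_duplicates(obj["study"].get("tags") or [])
    return obj


# Hypothetical inputs, for illustration only:
print(_avoid_duplicates(["Winter_Transition", "winter_transition", "decennial"]))
# ['winter_transition', 'decennial']
print(_load_patch_obj('{"study": {"tags": ["Tag1", "TAG1"]}}')["study"]["tags"])
# ['TAG1']
```
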
15 changes: 13 additions & 2 deletions antarest/study/model.py
@@ -6,7 +6,7 @@
from datetime import datetime, timedelta
from pathlib import Path

from pydantic import BaseModel
from pydantic import BaseModel, validator
from sqlalchemy import ( # type: ignore
Boolean,
Column,
@@ -351,7 +351,18 @@ class StudyMetadataPatchDTO(BaseModel):
scenario: t.Optional[str] = None
status: t.Optional[str] = None
doc: t.Optional[str] = None
tags: t.List[str] = []
tags: t.Sequence[str] = ()

@validator("tags", each_item=True)
def _normalize_tags(cls, v: str) -> str:
"""Remove leading and trailing whitespaces, and replace consecutive whitespaces by a single one."""
tag = " ".join(v.split())
if not tag:
raise ValueError("Tag cannot be empty")
elif len(tag) > 40:
raise ValueError(f"Tag is too long: {tag!r}")
else:
return tag


class StudySimSettingsDTO(BaseModel):
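
The new `tags` validator normalizes every submitted tag before it reaches the repository. Here is a minimal sketch of the same behavior, assuming pydantic v1 (matching the `validator(..., each_item=True)` API imported above); `TagsDTO` is a hypothetical stand-in for `StudyMetadataPatchDTO`:

```python
import typing as t

from pydantic import BaseModel, validator


class TagsDTO(BaseModel):
    tags: t.Sequence[str] = ()

    @validator("tags", each_item=True)
    def _normalize_tags(cls, v: str) -> str:
        # Strip leading/trailing whitespace and collapse runs of whitespace.
        tag = " ".join(v.split())
        if not tag:
            raise ValueError("Tag cannot be empty")
        if len(tag) > 40:
            raise ValueError(f"Tag is too long: {tag!r}")
        return tag


print(list(TagsDTO(tags=[" \t Winter  Transition \n", "decennial"]).tags))
# ['Winter Transition', 'decennial']
```
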
23 changes: 16 additions & 7 deletions antarest/study/repository.py
@@ -222,7 +222,8 @@ def get_all(
if study_filter.groups:
q = q.join(entity.groups).filter(Group.id.in_(study_filter.groups))
if study_filter.tags:
q = q.join(entity.tags).filter(Tag.label.in_(study_filter.tags))
upper_tags = [tag.upper() for tag in study_filter.tags]
q = q.join(entity.tags).filter(func.upper(Tag.label).in_(upper_tags))
if study_filter.archived is not None:
q = q.filter(entity.archived == study_filter.archived)
if study_filter.name:
@@ -279,17 +280,25 @@ def delete(self, id_: str, *ids: str) -> None:
def update_tags(self, study: Study, new_tags: t.Sequence[str]) -> None:
"""
Updates the tags associated with a given study in the database,
replacing existing tags with new ones.
replacing existing tags with new ones (case-insensitive).
Args:
study: The pre-existing study to be updated with the new tags.
new_tags: The new tags to be associated with the input study in the database.
"""
existing_tags = self.session.query(Tag).filter(Tag.label.in_(new_tags)).all()
new_labels = set(new_tags) - set([tag.label for tag in existing_tags])
study.tags = [Tag(label=tag) for tag in new_labels] + existing_tags
self.session.merge(study)
self.session.commit()
new_upper_tags = {tag.upper(): tag for tag in new_tags}
session = self.session
existing_tags = session.query(Tag).filter(func.upper(Tag.label).in_(new_upper_tags)).all()
for tag in existing_tags:
if tag.label.upper() in new_upper_tags:
new_upper_tags.pop(tag.label.upper())
study.tags = [Tag(label=tag) for tag in new_upper_tags.values()] + existing_tags
session.merge(study)
session.commit()
# Delete any tag that is not associated with any study.
# Note: If tags are to be associated with objects other than Study, this code must be updated.
session.query(Tag).filter(~Tag.studies.any()).delete(synchronize_session=False) # type: ignore
session.commit()

def list_duplicates(self) -> t.List[t.Tuple[str, str]]:
"""
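
In short, `update_tags` now reuses existing tag rows whose labels match case-insensitively, creates rows only for genuinely new labels, and finally deletes tags left without any study. A condensed sketch of that flow, with `session`, `Tag`, and `study` standing in for the repository's real session and models:

```python
import typing as t

from sqlalchemy import func
from sqlalchemy.orm import Session


def update_tags_sketch(session: Session, Tag: t.Any, study: t.Any, new_tags: t.Sequence[str]) -> None:
    """Illustrative only; not the production StudyMetadataRepository code path."""
    # Index the requested labels by their upper-cased form.
    wanted = {tag.upper(): tag for tag in new_tags}

    # Reuse existing tags whose labels match case-insensitively...
    existing = session.query(Tag).filter(func.upper(Tag.label).in_(list(wanted))).all()
    for tag in existing:
        wanted.pop(tag.label.upper(), None)

    # ...and create Tag rows only for labels that are actually new.
    study.tags = [Tag(label=label) for label in wanted.values()] + existing
    session.merge(study)
    session.commit()

    # Finally, drop any tag no longer attached to a study (orphan removal).
    session.query(Tag).filter(~Tag.studies.any()).delete(synchronize_session=False)
    session.commit()
```
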
15 changes: 8 additions & 7 deletions tests/integration/studies_blueprint/test_get_studies.py
@@ -319,7 +319,7 @@ def test_study_listing(
task = wait_task_completion(client, admin_access_token, archiving_study_task_id)
assert task.status == TaskStatus.COMPLETED, task

# create a raw study version 840 to be tagged with `winter_transition`
# create a raw study version 840 to be tagged with `Winter_Transition`
res = client.post(
STUDIES_URL,
headers={"Authorization": f"Bearer {admin_access_token}"},
@@ -330,7 +330,7 @@
res = client.put(
f"{STUDIES_URL}/{tagged_raw_840_id}",
headers={"Authorization": f"Bearer {admin_access_token}"},
json={"tags": ["winter_transition"]},
json={"tags": ["Winter_Transition"]},
)
assert res.status_code in CREATE_STATUS_CODES, res.json()
res = client.get(
@@ -341,7 +341,7 @@
assert res.status_code == LIST_STATUS_CODE, res.json()
study_map: t.Dict[str, t.Dict[str, t.Any]] = res.json()
assert len(study_map) == 1
assert set(study_map[tagged_raw_840_id]["tags"]) == {"winter_transition"}
assert set(study_map[tagged_raw_840_id]["tags"]) == {"Winter_Transition"}

# create a raw study version 850 to be tagged with `decennial`
res = client.post(
@@ -391,7 +391,8 @@ def test_study_listing(
assert len(study_map) == 1
assert set(study_map[tagged_variant_840_id]["tags"]) == {"decennial"}

# create a variant study version 850 to be tagged with `winter_transition`
# create a variant study version 850 to be tagged with `winter_transition`.
# also test that the tag label is case-insensitive.
res = client.post(
f"{STUDIES_URL}/{tagged_raw_850_id}/variants",
headers={"Authorization": f"Bearer {admin_access_token}"},
@@ -402,7 +403,7 @@
res = client.put(
f"{STUDIES_URL}/{tagged_variant_850_id}",
headers={"Authorization": f"Bearer {admin_access_token}"},
json={"tags": ["winter_transition"]},
json={"tags": ["winter_transition"]}, # note the tag label is in lower case
)
assert res.status_code in CREATE_STATUS_CODES, res.json()
res = client.get(
@@ -413,7 +414,7 @@
assert res.status_code == LIST_STATUS_CODE, res.json()
study_map = res.json()
assert len(study_map) == 1
assert set(study_map[tagged_variant_850_id]["tags"]) == {"winter_transition"}
assert set(study_map[tagged_variant_850_id]["tags"]) == {"Winter_Transition"}

# ==========================
# 2. Filtering testing
@@ -670,7 +671,7 @@ def test_study_listing(
res = client.get(
STUDIES_URL,
headers={"Authorization": f"Bearer {admin_access_token}"},
params={"tags": "decennial"},
params={"tags": "DECENNIAL"},
)
assert res.status_code == LIST_STATUS_CODE, res.json()
study_map = res.json()
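
End to end, the integration test above checks that a study tagged `Winter_Transition` is returned when filtering on `winter_transition` (and `decennial` when filtering on `DECENNIAL`). A hypothetical client-side sketch of the same calls, assuming the studies endpoint is `/v1/studies` and using placeholder host, token, and study id:

```python
import requests  # standalone client; the tests themselves use starlette's TestClient

BASE_URL = "http://localhost:8080"  # placeholder host
HEADERS = {"Authorization": "Bearer <access_token>"}  # placeholder token
STUDY_ID = "<study_id>"  # placeholder id

# Tag the study with a mixed-case label...
requests.put(f"{BASE_URL}/v1/studies/{STUDY_ID}", headers=HEADERS, json={"tags": ["Winter_Transition"]})

# ...then filter with a different case: the study is still listed.
res = requests.get(f"{BASE_URL}/v1/studies", headers=HEADERS, params={"tags": "winter_transition"})
assert STUDY_ID in res.json()  # the listing is a mapping keyed by study id
```
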
95 changes: 95 additions & 0 deletions tests/integration/studies_blueprint/test_update_tags.py
@@ -0,0 +1,95 @@
from starlette.testclient import TestClient


class TestUpdateStudyMetadata:
"""
Test the study tags update through the `update_study_metadata` API endpoint.
"""

def test_update_tags(
self,
client: TestClient,
user_access_token: str,
study_id: str,
) -> None:
"""
This test verifies that we can update the tags of a study.
It also tests the tags normalization.
"""

# Classic usage: set some tags to a study
study_tags = ["Tag1", "Tag2"]
res = client.put(
f"/v1/studies/{study_id}",
headers={"Authorization": f"Bearer {user_access_token}"},
json={"tags": study_tags},
)
assert res.status_code == 200, res.json()
actual = res.json()
assert set(actual["tags"]) == set(study_tags)

# Update the tags with already existing tags (case-insensitive):
# - "Tag1" is preserved, but with the same case as the existing one.
# - "Tag2" is replaced by "Tag3".
study_tags = ["tag1", "Tag3"]
res = client.put(
f"/v1/studies/{study_id}",
headers={"Authorization": f"Bearer {user_access_token}"},
json={"tags": study_tags},
)
assert res.status_code == 200, res.json()
actual = res.json()
assert set(actual["tags"]) != set(study_tags) # not the same case
assert set(tag.upper() for tag in actual["tags"]) == {"TAG1", "TAG3"}

# String normalization: whitespaces are stripped and
# consecutive whitespaces are replaced by a single one.
study_tags = [" \xa0Foo \t Bar \n ", " \t Baz\xa0\xa0"]
res = client.put(
f"/v1/studies/{study_id}",
headers={"Authorization": f"Bearer {user_access_token}"},
json={"tags": study_tags},
)
assert res.status_code == 200, res.json()
actual = res.json()
assert set(actual["tags"]) == {"Foo Bar", "Baz"}

# We can have symbols in the tags
study_tags = ["Foo-Bar", ":Baz%"]
res = client.put(
f"/v1/studies/{study_id}",
headers={"Authorization": f"Bearer {user_access_token}"},
json={"tags": study_tags},
)
assert res.status_code == 200, res.json()
actual = res.json()
assert set(actual["tags"]) == {"Foo-Bar", ":Baz%"}

def test_update_tags__invalid_tags(
self,
client: TestClient,
user_access_token: str,
study_id: str,
) -> None:
# We cannot have empty tags
study_tags = [""]
res = client.put(
f"/v1/studies/{study_id}",
headers={"Authorization": f"Bearer {user_access_token}"},
json={"tags": study_tags},
)
assert res.status_code == 422, res.json()
description = res.json()["description"]
assert "Tag cannot be empty" in description

# We cannot have tags longer than 40 characters
study_tags = ["very long tags, very long tags, very long tags"]
assert len(study_tags[0]) > 40
res = client.put(
f"/v1/studies/{study_id}",
headers={"Authorization": f"Bearer {user_access_token}"},
json={"tags": study_tags},
)
assert res.status_code == 422, res.json()
description = res.json()["description"]
assert "Tag is too long" in description
36 changes: 35 additions & 1 deletion tests/study/test_repository.py
@@ -631,7 +631,7 @@ def test_repository_get_all__study_tags_filter(

test_tag_1 = Tag(label="hidden-tag")
test_tag_2 = Tag(label="decennial")
test_tag_3 = Tag(label="winter_transition")
test_tag_3 = Tag(label="Winter_Transition") # note the different case

study_1 = VariantStudy(id=1, tags=[test_tag_1])
study_2 = VariantStudy(id=2, tags=[test_tag_2])
@@ -655,7 +655,41 @@ def test_repository_get_all__study_tags_filter(
_ = [s.groups for s in all_studies]
_ = [s.additional_data for s in all_studies]
_ = [s.tags for s in all_studies]

assert len(db_recorder.sql_statements) == 1, str(db_recorder)

if expected_ids is not None:
assert {s.id for s in all_studies} == expected_ids


def test_update_tags(
db_session: Session,
) -> None:
icache: Mock = Mock(spec=ICache)
repository = StudyMetadataRepository(cache_service=icache, session=db_session)

study_id = 1
study = RawStudy(id=study_id, tags=[])
db_session.add(study)
db_session.commit()

# use the db recorder to check that:
# 1- finding existing tags requires 1 query
# 2- updating the study tags requires 4 queries (2 selects, 2 inserts)
# 3- deleting orphan tags requires 1 query
with DBStatementRecorder(db_session.bind) as db_recorder:
repository.update_tags(study, ["Tag1", "Tag2"])
assert len(db_recorder.sql_statements) == 6, str(db_recorder)

# Check that when we change the tags to ["TAG1", "Tag3"],
# "Tag1" is preserved, "Tag2" is deleted and "Tag3" is created
# 1- finding existing tags requires 1 query
# 2- updating the study tags requires 5 queries (2 selects, 2 inserts, 1 delete)
# 3- deleting orphan tags requires 1 query
with DBStatementRecorder(db_session.bind) as db_recorder:
repository.update_tags(study, ["TAG1", "Tag3"])
assert len(db_recorder.sql_statements) == 7, str(db_recorder)

# Check that only "Tag1" and "Tag3" are present in the database
tags = db_session.query(Tag).all()
assert {tag.label for tag in tags} == {"Tag1", "Tag3"}
4 changes: 3 additions & 1 deletion webapp/src/utils/studiesUtils.ts
@@ -64,7 +64,9 @@ const tagsPredicate = R.curry(
if (!study.tags || study.tags.length === 0) {
return false;
}
return R.intersection(study.tags, tags).length > 0;
const upperCaseTags = tags.map((tag) => tag.toUpperCase());
const upperCaseStudyTags = study.tags.map((tag) => tag.toUpperCase());
return R.intersection(upperCaseStudyTags, upperCaseTags).length > 0;
},
);

