diff --git a/services/fixtures/exclusion_words.json b/services/fixtures/exclusion_words.json new file mode 100644 index 000000000..dd36151f8 --- /dev/null +++ b/services/fixtures/exclusion_words.json @@ -0,0 +1,18 @@ +[ + { + "model": "services.exclusionword", + "pk": 1, + "fields": { + "word": "katu", + "language_short": "fi" + } + }, + { + "model": "services.exclusionword", + "pk": 2, + "fields": { + "word": "tie", + "language_short": "fi" + } + } +] \ No newline at end of file diff --git a/services/management/commands/index_search_columns.py b/services/management/commands/index_search_columns.py index 11f9820c0..6ce9ca3be 100644 --- a/services/management/commands/index_search_columns.py +++ b/services/management/commands/index_search_columns.py @@ -1,11 +1,14 @@ import logging +from datetime import datetime, timedelta from django.contrib.postgres.search import SearchVector from django.core.management.base import BaseCommand +from django.utils import timezone from munigeo.models import Address, AdministrativeDivision from services.models import Service, ServiceNode, Unit -from services.search.utils import hyphenate +from services.search.constants import HYPHENATE_ADDRESSES_MODIFIED_WITHIN_DAYS +from services.search.utils import get_foreign_key_attr, hyphenate logger = logging.getLogger("services.management") @@ -27,17 +30,29 @@ def get_search_column(model, lang): return search_column -def generate_syllables(model): +def generate_syllables( + model, hyphenate_all_addresses=False, hyphenate_addresses_from=None +): """ Generates syllables for the given model. """ # Disable sending of signals model._meta.auto_created = True + save_kwargs = {} num_populated = 0 - for row in model.objects.all(): + if model.__name__ == "Address" and not hyphenate_all_addresses: + save_kwargs["skip_modified_at"] = True + if not hyphenate_addresses_from: + hyphenate_addresses_from = Address.objects.latest( + "modified_at" + ).modified_at - timedelta(days=HYPHENATE_ADDRESSES_MODIFIED_WITHIN_DAYS) + qs = model.objects.filter(modified_at__gte=hyphenate_addresses_from) + else: + qs = model.objects.all() + for row in qs: row.syllables_fi = [] for column in model.get_syllable_fi_columns(): - row_content = getattr(row, column, None) + row_content = get_foreign_key_attr(row, column) if row_content: # Rows might be of type str or Array, if str # cast to array by splitting. @@ -45,9 +60,10 @@ def generate_syllables(model): row_content = row_content.split() for word in row_content: syllables = hyphenate(word) - for s in syllables: - row.syllables_fi.append(s) - row.save() + if len(syllables) > 1: + for s in syllables: + row.syllables_fi.append(s) + row.save(**save_kwargs) num_populated += 1 # Enable sending of signals model._meta.auto_created = False @@ -85,13 +101,43 @@ def index_servicenodes(lang): class Command(BaseCommand): - def handle(self, *args, **kwargs): + def add_arguments(self, parser): + parser.add_argument( + "--hyphenate_addresses_from", + nargs="?", + type=str, + help="Hyphenate addresses whose modified_at timestamp starts at given timestamp YYYY-MM-DDTHH:MM:SS", + ) + + parser.add_argument( + "--hyphenate_all_addresses", + action="store_true", + help="Hyphenate all addresses", + ) + + def handle(self, *args, **options): + hyphenate_all_addresses = options.get("hyphenate_all_addresses", None) + hyphenate_addresses_from = options.get("hyphenate_addresses_from", None) + + if hyphenate_addresses_from: + try: + hyphenate_addresses_from = timezone.make_aware( + datetime.strptime(hyphenate_addresses_from, "%Y-%m-%dT%H:%M:%S") + ) + except ValueError as err: + raise ValueError(err) for lang in ["fi", "sv", "en"]: key = "search_column_%s" % lang # Only generate syllables for the finnish language if lang == "fi": logger.info(f"Generating syllables for language: {lang}.") logger.info(f"Syllables generated for {generate_syllables(Unit)} Units") + num_populated = generate_syllables( + Address, + hyphenate_all_addresses=hyphenate_all_addresses, + hyphenate_addresses_from=hyphenate_addresses_from, + ) + logger.info(f"Syllables generated for {num_populated} Addresses") logger.info( f"Syllables generated for {generate_syllables(Service)} Services" ) diff --git a/services/migrations/0117_exclusionword.py b/services/migrations/0117_exclusionword.py new file mode 100644 index 000000000..4fae5f0ad --- /dev/null +++ b/services/migrations/0117_exclusionword.py @@ -0,0 +1,37 @@ +# Generated by Django 5.0.6 on 2024-05-20 10:19 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ("services", "0116_alter_unit_address_postal_full_and_more"), + ] + + operations = [ + migrations.CreateModel( + name="ExclusionWord", + fields=[ + ( + "id", + models.AutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ("word", models.CharField(max_length=100, verbose_name="Word")), + ( + "language_short", + models.CharField(max_length=2, verbose_name="Language short"), + ), + ], + options={ + "verbose_name": "Exclusion word", + "verbose_name_plural": "Exclusion words", + "ordering": ["-id"], + }, + ), + ] diff --git a/services/models/__init__.py b/services/models/__init__.py index 59683c18b..2b74047de 100644 --- a/services/models/__init__.py +++ b/services/models/__init__.py @@ -4,7 +4,7 @@ from .keyword import Keyword from .mobility import MobilityServiceNode from .notification import Announcement, ErrorMessage -from .search_rule import ExclusionRule +from .search_rule import ExclusionRule, ExclusionWord from .service import Service, UnitServiceDetails from .service_mapping import ServiceMapping from .service_node import ServiceNode diff --git a/services/models/search_rule.py b/services/models/search_rule.py index 78c9c32b9..e1f0d8fe1 100644 --- a/services/models/search_rule.py +++ b/services/models/search_rule.py @@ -13,3 +13,16 @@ class Meta: def __str__(self): return "%s : %s" % (self.word, self.exclusion) + + +class ExclusionWord(models.Model): + word = models.CharField(max_length=100, verbose_name=_("Word")) + language_short = models.CharField(max_length=2, verbose_name=_("Language short")) + + class Meta: + ordering = ["-id"] + verbose_name = _("Exclusion word") + verbose_name_plural = _("Exclusion words") + + def __str__(self): + return self.word diff --git a/services/search/api.py b/services/search/api.py index 3d5db1548..ac858a1dd 100644 --- a/services/search/api.py +++ b/services/search/api.py @@ -27,9 +27,10 @@ from drf_spectacular.utils import extend_schema, OpenApiParameter from munigeo import api as munigeo_api from munigeo.models import Address, AdministrativeDivision -from rest_framework import serializers +from rest_framework import serializers, status from rest_framework.exceptions import ParseError from rest_framework.generics import GenericAPIView +from rest_framework.response import Response from services.api import ( TranslatedModelSerializer, @@ -60,6 +61,7 @@ get_preserved_order, get_service_node_results, get_trigram_results, + has_exclusion_word_in_query, set_address_fields, set_service_node_unit_count, set_service_unit_count, @@ -468,6 +470,12 @@ def get(self, request): else: search_query_str = f"{q}:*" + if has_exclusion_word_in_query(q_vals, language_short): + return Response( + f"Search query {q_vals} would return too many results", + status=status.HTTP_400_BAD_REQUEST, + ) + search_fn = "to_tsquery" if use_websearch: exclusions = self.get_search_exclusions(q) diff --git a/services/search/constants.py b/services/search/constants.py index 9d22fdaed..1d63d3b99 100644 --- a/services/search/constants.py +++ b/services/search/constants.py @@ -11,9 +11,11 @@ "Address", ) QUERY_PARAM_TYPE_NAMES = [m.lower() for m in SEARCHABLE_MODEL_TYPE_NAMES] -# None will slice to the end of list, e.g. no limit. +# None will slice to the end of list, i.e. no limit. DEFAULT_MODEL_LIMIT_VALUE = None # The limit value for the search query that search the search_view. "NULL" = no limit DEFAULT_SEARCH_SQL_LIMIT_VALUE = "NULL" DEFAULT_TRIGRAM_THRESHOLD = 0.15 DEFAULT_RANK_THRESHOLD = 0 + +HYPHENATE_ADDRESSES_MODIFIED_WITHIN_DAYS = 7 diff --git a/services/search/tests/conftest.py b/services/search/tests/conftest.py index 1010de99e..9cfa6167b 100644 --- a/services/search/tests/conftest.py +++ b/services/search/tests/conftest.py @@ -12,7 +12,10 @@ ) from rest_framework.test import APIClient -from services.management.commands.index_search_columns import get_search_column +from services.management.commands.index_search_columns import ( + generate_syllables, + get_search_column, +) from services.management.commands.services_import.services import ( update_service_counts, update_service_node_counts, @@ -20,6 +23,8 @@ ) from services.models import ( Department, + ExclusionRule, + ExclusionWord, Service, ServiceNode, Unit, @@ -243,6 +248,15 @@ def addresses(streets, municipality): number=1, full_name="Tarkk'ampujankatu 1", ) + Address.objects.create( + municipality_id=municipality.id, + location=Point(60.44879002342721, 22.283629416961055), + id=7, + street_id=46, + number=1, + full_name="Kellonsoittajankatu 1", + ) + generate_syllables(Address) Address.objects.update(search_column_fi=get_search_column(Address, "fi")) return Address.objects.all() @@ -280,4 +294,17 @@ def streets(): Street.objects.create(id=43, name="Markulantie", municipality_id="helsinki") Street.objects.create(id=44, name="Yliopistonkatu", municipality_id="helsinki") Street.objects.create(id=45, name="Tarkk'ampujankatu", municipality_id="helsinki") + Street.objects.create(id=46, name="Kellonsoittajankatu", municipality_id="helsinki") return Street.objects.all() + + +@pytest.fixture +def exclusion_rules(): + ExclusionRule.objects.create(id=1, word="tekojää", exclusion="-nurmi") + return ExclusionRule.objects.all() + + +@pytest.fixture +def exclusion_words(): + ExclusionWord.objects.create(id=1, word="katu", language_short="fi") + return ExclusionWord.objects.all() diff --git a/services/search/tests/test_api.py b/services/search/tests/test_api.py index 8deb6c597..67b2b82d8 100644 --- a/services/search/tests/test_api.py +++ b/services/search/tests/test_api.py @@ -13,6 +13,8 @@ def test_search( administrative_division, accessibility_shortcoming, municipality, + exclusion_rules, + exclusion_words, ): # Search for "museo" in entities: units,services and servicenods url = reverse("search") + "?q=museo&type=unit,service,servicenode" @@ -120,6 +122,22 @@ def test_search( assert kurrapolku["location"]["type"] == "Point" assert kurrapolku["location"]["coordinates"][0] == 60.479032 assert kurrapolku["location"]["coordinates"][1] == 22.25417 + # Test search with excluded word + url = reverse("search") + "?q=katu" + response = api_client.get(url) + assert response.status_code == 400 + url = reverse("search") + "?q=Katu" + response = api_client.get(url) + assert response.status_code == 400 + url = reverse("search") + "?q=koti katu" + response = api_client.get(url) + assert response.status_code == 400 + # Test search with 'kello' + url = reverse("search") + "?q=kello&type=address" + response = api_client.get(url) + results = response.json()["results"] + assert len(results) == 1 + assert results[0]["name"]["fi"] == "Kellonsoittajankatu 1" # Test address search with apostrophe in query url = reverse("search") + "?q=tarkk'ampujankatu&type=address" response = api_client.get(url) diff --git a/services/search/utils.py b/services/search/utils.py index 6790c203d..8928c85a3 100644 --- a/services/search/utils.py +++ b/services/search/utils.py @@ -1,17 +1,44 @@ +import logging + import libvoikko from django.db import connection from django.db.models import Case, When +from django.db.models.functions import Lower +from rest_framework.exceptions import ParseError -from services.models import ServiceNode, ServiceNodeUnitCount, Unit +from services.models import ( + ExclusionRule, + ExclusionWord, + ServiceNode, + ServiceNodeUnitCount, + Unit, +) from services.search.constants import ( DEFAULT_TRIGRAM_THRESHOLD, SEARCHABLE_MODEL_TYPE_NAMES, ) +logger = logging.getLogger("search") voikko = libvoikko.Voikko("fi") voikko.setNoUglyHyphenation(True) +def get_foreign_key_attr(obj, field): + """Get attr recursively by following foreign key relations + For example: + get_foreign_key_attr( + , "street__name_fi" + ) + """ + fields = field.split("__") + if len(fields) == 1: + return getattr(obj, fields[0], None) + else: + first_field = fields[0] + remaining_fields = "__".join(fields[1:]) + return get_foreign_key_attr(getattr(obj, first_field), remaining_fields) + + def is_compound_word(word): result = voikko.analyze(word) if len(result) == 0: @@ -21,7 +48,7 @@ def is_compound_word(word): def hyphenate(word): """ - Returns a list of syllables of the word if it is a compound word. + Returns a list of syllables of the word, if it is a compound word. """ word = word.strip() if is_compound_word(word): @@ -199,15 +226,42 @@ def get_preserved_order(ids): def get_trigram_results( model, model_name, field, q_val, threshold=DEFAULT_TRIGRAM_THRESHOLD ): - sql = f"""SELECT id, similarity({field}, '{q_val}') AS sml + sql = f"""SELECT id, similarity({field}, %s) AS sml FROM {model_name} - WHERE similarity({field}, '{q_val}') >= {threshold} + WHERE similarity({field}, %s) >= {threshold} ORDER BY sml DESC; """ cursor = connection.cursor() - cursor.execute(sql) + try: + cursor.execute(sql, [q_val, q_val]) + except Exception as e: + logger.error(f"Error in similarity query: {e}") + raise ParseError("Similarity query failed.") all_results = cursor.fetchall() - ids = [row[0] for row in all_results] objs = model.objects.filter(id__in=ids) return objs + + +def get_search_exclusions(q): + """ + To add/modify search exclusion rules edit: services/fixtures/exclusion_rules + To import rules: ./manage.py loaddata services/fixtures/exclusion_rules.json + """ + rule = ExclusionRule.objects.filter(word__iexact=q).first() + if rule: + return rule.exclusion + return "" + + +def has_exclusion_word_in_query(q_vals, language_short): + """ + To add/modify search exclusion words edit: services/fixtures/exclusion_words.json + To import words: ./manage.py loaddata services/fixtures/exclusion_words.json + """ + return ( + ExclusionWord.objects.filter(language_short=language_short) + .annotate(word_lower=Lower("word")) + .filter(word_lower__in=[q.lower() for q in q_vals]) + .exists() + )