diff --git a/services/api.py b/services/api.py index f45a26f4..6535fe92 100644 --- a/services/api.py +++ b/services/api.py @@ -991,6 +991,7 @@ def render(self, data, media_type=None, renderer_context=None): LEVEL_PARAMETER, UNIT_GEOMETRY_PARAMETER, UNIT_GEOMETRY_3D_PARAMETER, + BBOX_PARAMETER, ] ) class UnitViewSet( diff --git a/services/search/api.py b/services/search/api.py index 9d9bd1bf..76d2e39d 100644 --- a/services/search/api.py +++ b/services/search/api.py @@ -22,11 +22,13 @@ import re from itertools import chain +from django.contrib.gis.gdal import SpatialReference from django.db import connection, reset_queries from django.db.models import Count from drf_spectacular.utils import extend_schema, OpenApiParameter from munigeo import api as munigeo_api from munigeo.models import Address, AdministrativeDivision +from munigeo.utils import get_default_srid from rest_framework import serializers, status from rest_framework.exceptions import ParseError from rest_framework.generics import GenericAPIView @@ -344,6 +346,13 @@ def to_representation(self, obj): required=False, type=str, ), + OpenApiParameter( + name="bbox", + location=OpenApiParameter.QUERY, + description="Bounding box in the format 'left,bottom,right,top'.", + required=False, + type=str, + ), ], description="Search for units, services, service nodes, addresses and administrative divisions.", ) @@ -587,6 +596,16 @@ def get(self, request): if services[0]: units_qs = units_qs.filter(services__in=services) + if "bbox" in self.request.query_params: + bbox = self.request.query_params["bbox"] + if "bbox_srid" in self.request.query_params: + bbox_srid = self.request.query_params["bbox_srid"] + else: + bbox_srid = get_default_srid() + ref = SpatialReference(bbox_srid) + bbox_filter = munigeo_api.build_bbox_filter(ref, bbox, "location") + units_qs = units_qs.filter(**bbox_filter) + if units_order_list: units_qs = units_qs.annotate(num_services=Count("services")).order_by( *units_order_list diff --git a/services/search/tests/conftest.py b/services/search/tests/conftest.py index 9cfa6167..8bd95434 100644 --- a/services/search/tests/conftest.py +++ b/services/search/tests/conftest.py @@ -96,6 +96,7 @@ def units( last_modified_time=now(), municipality=municipality, department=department, + location=Point(24.941387, 60.17103, srid=4326), # Helsinki center ) # Add service Halli unit.services.add(4) diff --git a/services/search/tests/test_api.py b/services/search/tests/test_api.py index 37665365..7e1fabef 100644 --- a/services/search/tests/test_api.py +++ b/services/search/tests/test_api.py @@ -264,3 +264,23 @@ def test_search_with_vertical_bar_in_query(api_client, units): url = reverse("search") + "?q=|terveysasema||''||'" response = api_client.get(url) assert response.status_code == 200 + + +@pytest.mark.django_db +def test_search_with_bbox_parameter(api_client, units): + """ + When bbox parameter is given, only units within the bounding box should be returned. + """ + url = reverse("search") + "?q=halli&type=unit" + response = api_client.get(url) + results = response.json()["results"] + assert len(results) == 3 + + url = ( + reverse("search") + + "?q=halli&type=unit&bbox=24.93545,60.16952,24.95190,60.17800&bbox_srid=4326" + ) + response = api_client.get(url) + results = response.json()["results"] + assert len(results) == 1 + assert results[0]["name"]["fi"] == "Jäähalli"