diff --git a/datagateway_api/config.json.example b/datagateway_api/config.json.example index 1b86bf57..5e31370b 100644 --- a/datagateway_api/config.json.example +++ b/datagateway_api/config.json.example @@ -15,7 +15,11 @@ "icat_check_cert": false, "mechanism": "anon", "username": "", - "password": "" + "password": "", + "scoring_enabled": false, + "scoring_server": "http://localhost:9000/score", + "scoring_group": "investigation" #corresponds to the defined group in the scoring app. https://github.com/panosc-eu/panosc-search-scoring/blob/master/docs/md/PaNOSC_Federated_Search_Results_Scoring_API.md#model + "scoring_limit": 1000 }, "flask_reloader": false, "log_level": "WARN", diff --git a/datagateway_api/src/api_start_utils.py b/datagateway_api/src/api_start_utils.py index a56e61bc..8a47aa2d 100644 --- a/datagateway_api/src/api_start_utils.py +++ b/datagateway_api/src/api_start_utils.py @@ -287,9 +287,9 @@ def create_api_endpoints(flask_app, api, specs): ) search_api_extension = Config.config.search_api.extension search_api_entity_endpoints = { - "datasets": "Dataset", - "documents": "Document", - "instruments": "Instrument", + "Datasets": "Dataset", + "Documents": "Document", + "Instruments": "Instrument", } for endpoint_name, entity_name in search_api_entity_endpoints.items(): diff --git a/datagateway_api/src/common/config.py b/datagateway_api/src/common/config.py index a23dbb6f..d0f02233 100644 --- a/datagateway_api/src/common/config.py +++ b/datagateway_api/src/common/config.py @@ -133,6 +133,10 @@ class SearchAPI(BaseModel): mechanism: StrictStr username: StrictStr password: StrictStr + scoring_enabled: StrictBool + scoring_server: StrictStr + scoring_group: StrictStr + scoring_limit: StrictInt _validate_extension = validator("extension", allow_reuse=True)(validate_extension) diff --git a/datagateway_api/src/common/exceptions.py b/datagateway_api/src/common/exceptions.py index 87498ef6..9798af64 100644 --- a/datagateway_api/src/common/exceptions.py +++ b/datagateway_api/src/common/exceptions.py @@ -59,3 +59,9 @@ class SearchAPIError(ApiError): def __init__(self, msg="Search API error", *args, **kwargs): super().__init__(msg, *args, **kwargs) self.status_code = 500 + + +class ScoringAPIError(ApiError): + def __init__(self, msg="Scoring API error", *args, **kwargs): + super().__init__(msg, *args, **kwargs) + self.status_code = 500 diff --git a/datagateway_api/src/common/filters.py b/datagateway_api/src/common/filters.py index bd0e636c..1e420520 100644 --- a/datagateway_api/src/common/filters.py +++ b/datagateway_api/src/common/filters.py @@ -81,3 +81,10 @@ class IncludeFilter(QueryFilter): def __init__(self, included_filters): self.included_filters = included_filters + + +class ScoringQueryFilter(QueryFilter): + precedence = 6 + + def __init__(self, value): + self.value = value diff --git a/datagateway_api/src/resources/search_api_endpoints.py b/datagateway_api/src/resources/search_api_endpoints.py index 7c89b89f..987ade08 100644 --- a/datagateway_api/src/resources/search_api_endpoints.py +++ b/datagateway_api/src/resources/search_api_endpoints.py @@ -2,13 +2,18 @@ from flask_restful import Resource +from datagateway_api.src.common.config import Config from datagateway_api.src.common.helpers import get_filters_from_query_string from datagateway_api.src.search_api.helpers import ( + add_scores_to_entities, get_count, get_files, get_files_count, + get_score, get_search, + get_search_api_query_filter_list, get_with_pid, + not_query_filter, search_api_error_handling, ) @@ -30,8 +35,30 @@ class Endpoint(Resource): @search_api_error_handling def get(self): filters = get_filters_from_query_string("search_api", entity_name) - log.debug("Filters: %s", filters) - return get_search(entity_name, filters), 200 + log.debug( + "%s Filters: %s found, entity_name: %s", + len(filters), + filters, + entity_name, + ) + # in case there is no query filter then we processed as usual + if not not_query_filter(filters): + return get_search(entity_name, filters), 200 + else: + query = get_search_api_query_filter_list(filters)[0].value + log.debug("Performing the search") + entities = get_search( + entity_name, + filters, + "LOWER(o.summary) like '%" + query.lower() + "%'", + ) + log.debug( + "Applying score to %s entities with query %s", len(entities), query, + ) + if Config.config.search_api.scoring_enabled: + scores = get_score(entities, query) + entities = add_scores_to_entities(entities, scores) + return entities, 200 get.__doc__ = f""" --- diff --git a/datagateway_api/src/search_api/filters.py b/datagateway_api/src/search_api/filters.py index f832d858..13a7ac03 100644 --- a/datagateway_api/src/search_api/filters.py +++ b/datagateway_api/src/search_api/filters.py @@ -8,6 +8,7 @@ PythonICATLimitFilter, PythonICATSkipFilter, PythonICATWhereFilter, + ScoringQueryFilter, ) from datagateway_api.src.search_api.models import PaNOSCAttribute from datagateway_api.src.search_api.panosc_mappings import mappings @@ -162,6 +163,14 @@ def apply_filter(self, query): return super().apply_filter(query.icat_query.query) +class SearchAPIScoringFilter(ScoringQueryFilter): + def __init__(self, query_value): + super().__init__(query_value) + + def apply_filter(self, query): + return + + class SearchAPIIncludeFilter(PythonICATIncludeFilter): def __init__(self, included_filters, panosc_entity_name): self.included_filters = included_filters @@ -183,6 +192,7 @@ def apply_filter(self, query): panosc_entity_name, icat_field_name = mappings.get_icat_mapping( panosc_entity_name, split_field, ) + split_icat_field_name.append(icat_field_name) icat_field_names.append(".".join(split_icat_field_name)) diff --git a/datagateway_api/src/search_api/helpers.py b/datagateway_api/src/search_api/helpers.py index e4c125ad..b3174f5b 100644 --- a/datagateway_api/src/search_api/helpers.py +++ b/datagateway_api/src/search_api/helpers.py @@ -3,10 +3,13 @@ import logging from pydantic import ValidationError +import requests +from datagateway_api.src.common.config import Config from datagateway_api.src.common.exceptions import ( BadRequestError, MissingRecordError, + ScoringAPIError, SearchAPIError, ) from datagateway_api.src.common.filter_order_handler import FilterOrderHandler @@ -14,6 +17,7 @@ SearchAPIIncludeFilter, SearchAPIWhereFilter, ) +from datagateway_api.src.search_api.filters import SearchAPIScoringFilter import datagateway_api.src.search_api.models as models from datagateway_api.src.search_api.query import SearchAPIQuery from datagateway_api.src.search_api.session_handler import ( @@ -21,7 +25,6 @@ SessionHandler, ) - log = logging.getLogger() @@ -43,6 +46,10 @@ def wrapper_error_handling(*args, **kwargs): log.exception(msg=e.args) assign_status_code(e, 500) raise SearchAPIError(create_error_message(e)) + except ConnectionError as e: + log.exception(msg=e.args) + assign_status_code(e, 500) + raise ScoringAPIError(create_error_message(e)) except (ValueError, TypeError, AttributeError, KeyError) as e: log.exception(msg=e.args) assign_status_code(e, 400) @@ -74,8 +81,71 @@ def create_error_message(e): return wrapper_error_handling +def get_score(entities, query): + """ + Gets the score on the given entities based in the query parameter + that is the term to be found + + :param entities: List of entities that have been retrieved from one ICAT query. + :type entities: :class:`list` + :param query: String with the term to be searched by + :type query: :class:`str` + """ + try: + data = { + "query": query, + "group": Config.config.search_api.scoring_group, + "limit": Config.config.search_api.scoring_limit, + # With itemIds, scoring server returns a 400 error. No idea why. + # "itemIds": list(map(lambda entity: (entity["pid"]), entities)), # + } + response = requests.post( + Config.config.search_api.scoring_server, json=data, timeout=5, + ) + if response.status_code < 400: + scores = response.json()["scores"] + log.debug( + "%s scores out of %s entities retrieved", len(scores), len(entities), + ) + return scores + else: + raise ScoringAPIError( + Exception(f"Score API returned {response.status_code}"), + ) + except ValueError as e: + log.error("Response is not a valid json") + raise e + except ConnectionError as e: + log.error("ConnectionError to %s ", Config.config.search_api.scoring_server) + raise e + except Exception as e: + log.error("Error on scoring") + raise e + + +def add_scores_to_entities(entities, scores): + """ + For each entity this function adds the score if it is found by matching + the score.item.itemsId with the pid of the entity + Otherwise the score is filled with -1 (arbitrarily chosen) + + :param entities: List of entities that have been retrieved from one ICAT query. + :type entities: :class:`list` + :param scores: List of items retrieved from the scoring application + :type scores: :class:`list` + """ + for entity in entities: + entity["score"] = -1 + items = list( + filter(lambda score: str(score["itemId"]) == str(entity["pid"]), scores), + ) + if len(items) == 1: + entity["score"] = items[0]["score"] + return entities + + @client_manager -def get_search(entity_name, filters): +def get_search(entity_name, filters, str_conditions=None): """ Search for data on the given entity, using filters from the request to restrict the query @@ -84,6 +154,8 @@ def get_search(entity_name, filters): :type entity_name: :class:`str` :param filters: The list of Search API filters to be applied to the request/query :type filters: List of specific implementation :class:`QueryFilter` + :param str_conditions: Where clause to be applied to the JPQL query + :type str_conditions: :class:`str` :return: List of records (in JSON serialisable format) of the given entity for the query constructed from that and the request's filters """ @@ -96,8 +168,8 @@ def get_search(entity_name, filters): if isinstance(filter_, SearchAPIIncludeFilter): entity_relations.extend(filter_.included_filters) - query = SearchAPIQuery(entity_name) - + query = SearchAPIQuery(entity_name, str_conditions=str_conditions) + log.debug("Query: %s", query) filter_handler = FilterOrderHandler() filter_handler.add_filters(filters) filter_handler.add_icat_relations_for_panosc_non_related_fields(entity_name) @@ -198,8 +270,11 @@ def get_files(entity_name, pid, filters): """ log.info("Getting files of dataset (PID: %s), using request's filters", pid) - log.debug( - "Entity Name: %s, Filters: %s", entity_name, filters, + log.debug + ( + "Entity Name: %s, Filters: %s", + entity_name, + filters, ) filters.append(SearchAPIWhereFilter("dataset.pid", pid, "eq")) @@ -229,3 +304,18 @@ def get_files_count(entity_name, filters, pid): filters.append(SearchAPIWhereFilter("dataset.pid", pid, "eq")) return get_count(entity_name, filters) + + +def get_search_api_query_filter_list(filters): + """ + Returns the list of SearchAPIQueryFilter that are in the filters array + """ + return list(filter(lambda x: isinstance(x, SearchAPIScoringFilter), filters)) + + +@client_manager +def not_query_filter(filters): + """ + Checks if there is a SearchAPIQueryFilter in the list of filters + """ + return len(get_search_api_query_filter_list(filters)) == 1 diff --git a/datagateway_api/src/search_api/models.py b/datagateway_api/src/search_api/models.py index 2fc1e3ad..346b9065 100644 --- a/datagateway_api/src/search_api/models.py +++ b/datagateway_api/src/search_api/models.py @@ -210,7 +210,7 @@ class Dataset(PaNOSCAttribute): @validator("pid", pre=True, always=True) def set_pid(cls, value): # noqa: B902, N805 - return f"pid:{value}" if isinstance(value, int) else value + return f"{value}" @root_validator(pre=True) def set_is_public(cls, values): # noqa: B902, N805 @@ -250,7 +250,7 @@ class Document(PaNOSCAttribute): @validator("pid", pre=True, always=True) def set_pid(cls, value): # noqa: B902, N805 - return f"pid:{value}" if isinstance(value, int) else value + return value @root_validator(pre=True) def set_is_public(cls, values): # noqa: B902, N805 @@ -296,7 +296,7 @@ class Instrument(PaNOSCAttribute): @validator("pid", pre=True, always=True) def set_pid(cls, value): # noqa: B902, N805 - return f"pid:{value}" if isinstance(value, int) else value + return value @classmethod def from_icat(cls, icat_data, required_related_fields): @@ -399,7 +399,7 @@ class Sample(PaNOSCAttribute): @validator("pid", pre=True, always=True) def set_pid(cls, value): # noqa: B902, N805 - return f"pid:{value}" if isinstance(value, int) else value + return value @classmethod def from_icat(cls, icat_data, required_related_fields): @@ -419,7 +419,7 @@ class Technique(PaNOSCAttribute): @validator("pid", pre=True, always=True) def set_pid(cls, value): # noqa: B902, N805 - return f"pid:{value}" if isinstance(value, int) else value + return value @classmethod def from_icat(cls, icat_data, required_related_fields): diff --git a/datagateway_api/src/search_api/panosc_mappings.py b/datagateway_api/src/search_api/panosc_mappings.py index 695c0245..25d0bdae 100644 --- a/datagateway_api/src/search_api/panosc_mappings.py +++ b/datagateway_api/src/search_api/panosc_mappings.py @@ -43,14 +43,13 @@ def get_icat_mapping(self, panosc_entity_name, field_name): :raises FilterError: If a valid mapping cannot be found """ - log.info( - "Searching mapping file to find ICAT translation for %s", - f"{panosc_entity_name}.{field_name}", - ) + # log.debug( + # "Searching mapping file to find ICAT translation for %s", + # f"{panosc_entity_name}.{field_name}", + # ) try: icat_mapping = self.mappings[panosc_entity_name][field_name] - log.debug("ICAT mapping/translation found: %s", icat_mapping) except KeyError as e: raise FilterError(f"Bad PaNOSC to ICAT mapping: {e.args}") diff --git a/datagateway_api/src/search_api/query.py b/datagateway_api/src/search_api/query.py index e5a06b2c..4fdcea7d 100644 --- a/datagateway_api/src/search_api/query.py +++ b/datagateway_api/src/search_api/query.py @@ -11,7 +11,6 @@ def __init__(self, panosc_entity_name, **kwargs): self.icat_entity_name = mappings.mappings[panosc_entity_name][ "base_icat_entity" ] - self.icat_query = SearchAPIICATQuery( SessionHandler.client, self.icat_entity_name, **kwargs, ) diff --git a/datagateway_api/src/search_api/query_filter_factory.py b/datagateway_api/src/search_api/query_filter_factory.py index fd669d63..b7f51491 100644 --- a/datagateway_api/src/search_api/query_filter_factory.py +++ b/datagateway_api/src/search_api/query_filter_factory.py @@ -6,6 +6,7 @@ from datagateway_api.src.search_api.filters import ( SearchAPIIncludeFilter, SearchAPILimitFilter, + SearchAPIScoringFilter, SearchAPISkipFilter, SearchAPIWhereFilter, ) @@ -58,7 +59,12 @@ def get_query_filter(request_filter, entity_name=None, related_entity_name=None) ) elif filter_name == "limit": log.info("limit JSON object found") - query_filters.append(SearchAPILimitFilter(filter_input)) + query_filters.append(SearchAPILimitFilter(int(filter_input))) + + elif filter_name == "query": + log.info("query JSON object found") + query_filters.append(SearchAPIScoringFilter(filter_input)) + elif filter_name == "skip": log.info("skip JSON object found") query_filters.append(SearchAPISkipFilter(filter_input)) diff --git a/datagateway_api/src/swagger/search_api/openapi.yaml b/datagateway_api/src/swagger/search_api/openapi.yaml index 7df70e9a..60aaa0f4 100644 --- a/datagateway_api/src/swagger/search_api/openapi.yaml +++ b/datagateway_api/src/swagger/search_api/openapi.yaml @@ -469,7 +469,7 @@ info: version: '1.0' openapi: 3.0.3 paths: - /search-api/datasets: + /search-api/Datasets: get: description: Retrieves a list of Dataset objects parameters: @@ -490,7 +490,7 @@ paths: summary: Get Datasets tags: - Dataset - /search-api/datasets/{pid}: + /search-api/Datasets/{pid}: get: description: Retrieves a Dataset object with the matching pid parameters: @@ -516,7 +516,7 @@ paths: summary: Find the Dataset matching the given pid tags: - Dataset - /search-api/datasets/count: + /search-api/Datasets/count: get: description: Return the count of the Dataset objects that would be retrieved given the filters provided @@ -536,7 +536,7 @@ paths: summary: Count Datasets tags: - Dataset - /search-api/documents: + /search-api/Documents: get: description: Retrieves a list of Document objects parameters: @@ -557,7 +557,7 @@ paths: summary: Get Documents tags: - Document - /search-api/documents/{pid}: + /search-api/Documents/{pid}: get: description: Retrieves a Document object with the matching pid parameters: @@ -583,7 +583,7 @@ paths: summary: Find the Document matching the given pid tags: - Document - /search-api/documents/count: + /search-api/Documents/count: get: description: Return the count of the Document objects that would be retrieved given the filters provided @@ -603,7 +603,7 @@ paths: summary: Count Documents tags: - Document - /search-api/instruments: + /search-api/Instruments: get: description: Retrieves a list of Instrument objects parameters: @@ -624,7 +624,7 @@ paths: summary: Get Instruments tags: - Instrument - /search-api/instruments/{pid}: + /search-api/Instruments/{pid}: get: description: Retrieves a Instrument object with the matching pid parameters: @@ -650,7 +650,7 @@ paths: summary: Find the Instrument matching the given pid tags: - Instrument - /search-api/instruments/count: + /search-api/Instruments/count: get: description: Return the count of the Instrument objects that would be retrieved given the filters provided