Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: Support query filter and scoring #366

Closed
wants to merge 20 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion datagateway_api/config.json.example
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@
"icat_check_cert": false,
"mechanism": "anon",
"username": "",
"password": ""
"password": "",
"scoring_enabled": false,
"scoring_server": "http://localhost:9000/score",
"scoring_group": "investigation" #corresponds to the defined group in the scoring app. https://github.com/panosc-eu/panosc-search-scoring/blob/master/docs/md/PaNOSC_Federated_Search_Results_Scoring_API.md#model
"scoring_limit": 1000
},
"flask_reloader": false,
"log_level": "WARN",
Expand Down
6 changes: 3 additions & 3 deletions datagateway_api/src/api_start_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,9 +287,9 @@ def create_api_endpoints(flask_app, api, specs):
)
search_api_extension = Config.config.search_api.extension
search_api_entity_endpoints = {
"datasets": "Dataset",
"documents": "Document",
"instruments": "Instrument",
"Datasets": "Dataset",
"Documents": "Document",
"Instruments": "Instrument",
antolinos marked this conversation as resolved.
Show resolved Hide resolved
}

for endpoint_name, entity_name in search_api_entity_endpoints.items():
Expand Down
4 changes: 4 additions & 0 deletions datagateway_api/src/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ class SearchAPI(BaseModel):
mechanism: StrictStr
username: StrictStr
password: StrictStr
scoring_enabled: StrictBool
scoring_server: StrictStr
scoring_group: StrictStr
Comment on lines +136 to +138
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add example values to config.json.example please? Keeping the scoring disabled might be best in case of someone using DataGateway API only, not the search API.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is my config:

    "scoring_enabled": true,
    "scoring_server": "http://dau-dm-01:9000/score?limit=2000",
    "scoring_group": "investigation"

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this is your own scoring server? Does this mean I'll need to setup my own instance to test your changes? Is there a PaNOSC scoring server that I could use?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, there is not as far as I know and take into account that you need to install and populate with your own data in order to calculate the weights.
It is not a big deal but has to be done.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's good to know, I will setup my own and test this branch when I get a chance

scoring_limit: StrictInt

_validate_extension = validator("extension", allow_reuse=True)(validate_extension)

Expand Down
6 changes: 6 additions & 0 deletions datagateway_api/src/common/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,9 @@ class SearchAPIError(ApiError):
def __init__(self, msg="Search API error", *args, **kwargs):
super().__init__(msg, *args, **kwargs)
self.status_code = 500


class ScoringAPIError(ApiError):
def __init__(self, msg="Scoring API error", *args, **kwargs):
super().__init__(msg, *args, **kwargs)
self.status_code = 500
7 changes: 7 additions & 0 deletions datagateway_api/src/common/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,10 @@ class IncludeFilter(QueryFilter):

def __init__(self, included_filters):
self.included_filters = included_filters


class ScoringQueryFilter(QueryFilter):
precedence = 6

def __init__(self, value):
self.value = value
31 changes: 29 additions & 2 deletions datagateway_api/src/resources/search_api_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@

from flask_restful import Resource

from datagateway_api.src.common.config import Config
from datagateway_api.src.common.helpers import get_filters_from_query_string
from datagateway_api.src.search_api.helpers import (
add_scores_to_entities,
get_count,
get_files,
get_files_count,
get_score,
get_search,
get_search_api_query_filter_list,
get_with_pid,
not_query_filter,
search_api_error_handling,
)

Expand All @@ -30,8 +35,30 @@ class Endpoint(Resource):
@search_api_error_handling
def get(self):
filters = get_filters_from_query_string("search_api", entity_name)
log.debug("Filters: %s", filters)
return get_search(entity_name, filters), 200
log.debug(
"%s Filters: %s found, entity_name: %s",
len(filters),
filters,
entity_name,
)
# in case there is no query filter then we processed as usual
if not not_query_filter(filters):
return get_search(entity_name, filters), 200
else:
query = get_search_api_query_filter_list(filters)[0].value
log.debug("Performing the search")
entities = get_search(
entity_name,
filters,
"LOWER(o.summary) like '%" + query.lower() + "%'",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking at this line and line 56 (where you pass "investigations" to get_score()), I guess this only works on /documents for now? Have you got a plan of how to make it work for the other endpoints?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right and honestly I do not know. There are several reason:

  1. I focused on making https://data.panosc.eu/ to work. Surprisingly I did discover that only use /documents endpoint. So, I did not consider to add something that is not used
  2. When I tried to use the scoring app with the datasets, it did not work when calculating the score (my guess is because of the number of datasets). I created an issue:
    Compute gets stuck panosc-eu/panosc-search-scoring#9

So, yes, it might be needed in the future but it very uncertain and first we would need the scoring app to work at the level of datasets

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We were slightly concerned about the way the score calculation behaves and that it might cause performance issues with large volumes of data. It would be good to confirm (from someone who wrote the scoring software perhaps?) which endpoint(s) the scoring needs to work on.

)
log.debug(
"Applying score to %s entities with query %s", len(entities), query,
)
if Config.config.search_api.scoring_enabled:
scores = get_score(entities, query)
entities = add_scores_to_entities(entities, scores)
return entities, 200

get.__doc__ = f"""
---
Expand Down
10 changes: 10 additions & 0 deletions datagateway_api/src/search_api/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
PythonICATLimitFilter,
PythonICATSkipFilter,
PythonICATWhereFilter,
ScoringQueryFilter,
)
from datagateway_api.src.search_api.models import PaNOSCAttribute
from datagateway_api.src.search_api.panosc_mappings import mappings
Expand Down Expand Up @@ -162,6 +163,14 @@ def apply_filter(self, query):
return super().apply_filter(query.icat_query.query)


class SearchAPIScoringFilter(ScoringQueryFilter):
def __init__(self, query_value):
super().__init__(query_value)

def apply_filter(self, query):
return


class SearchAPIIncludeFilter(PythonICATIncludeFilter):
def __init__(self, included_filters, panosc_entity_name):
self.included_filters = included_filters
Expand All @@ -183,6 +192,7 @@ def apply_filter(self, query):
panosc_entity_name, icat_field_name = mappings.get_icat_mapping(
panosc_entity_name, split_field,
)

split_icat_field_name.append(icat_field_name)

icat_field_names.append(".".join(split_icat_field_name))
Expand Down
102 changes: 96 additions & 6 deletions datagateway_api/src/search_api/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,28 @@
import logging

from pydantic import ValidationError
import requests

from datagateway_api.src.common.config import Config
from datagateway_api.src.common.exceptions import (
BadRequestError,
MissingRecordError,
ScoringAPIError,
SearchAPIError,
)
from datagateway_api.src.common.filter_order_handler import FilterOrderHandler
from datagateway_api.src.search_api.filters import (
SearchAPIIncludeFilter,
SearchAPIWhereFilter,
)
from datagateway_api.src.search_api.filters import SearchAPIScoringFilter
import datagateway_api.src.search_api.models as models
from datagateway_api.src.search_api.query import SearchAPIQuery
from datagateway_api.src.search_api.session_handler import (
client_manager,
SessionHandler,
)


log = logging.getLogger()


Expand All @@ -43,6 +46,10 @@ def wrapper_error_handling(*args, **kwargs):
log.exception(msg=e.args)
assign_status_code(e, 500)
raise SearchAPIError(create_error_message(e))
except ConnectionError as e:
log.exception(msg=e.args)
assign_status_code(e, 500)
raise ScoringAPIError(create_error_message(e))
except (ValueError, TypeError, AttributeError, KeyError) as e:
log.exception(msg=e.args)
assign_status_code(e, 400)
Expand Down Expand Up @@ -74,8 +81,71 @@ def create_error_message(e):
return wrapper_error_handling


def get_score(entities, query):
"""
Gets the score on the given entities based in the query parameter
that is the term to be found

:param entities: List of entities that have been retrieved from one ICAT query.
:type entities: :class:`list`
:param query: String with the term to be searched by
:type query: :class:`str`
"""
try:
data = {
"query": query,
"group": Config.config.search_api.scoring_group,
"limit": Config.config.search_api.scoring_limit,
# With itemIds, scoring server returns a 400 error. No idea why.
# "itemIds": list(map(lambda entity: (entity["pid"]), entities)), #
}
response = requests.post(
Config.config.search_api.scoring_server, json=data, timeout=5,
)
if response.status_code < 400:
scores = response.json()["scores"]
log.debug(
"%s scores out of %s entities retrieved", len(scores), len(entities),
)
return scores
else:
raise ScoringAPIError(
Exception(f"Score API returned {response.status_code}"),
)
except ValueError as e:
log.error("Response is not a valid json")
raise e
except ConnectionError as e:
log.error("ConnectionError to %s ", Config.config.search_api.scoring_server)
raise e
except Exception as e:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are there any other specific exceptions that we could catch, instead of a generic Exception?

log.error("Error on scoring")
raise e


def add_scores_to_entities(entities, scores):
"""
For each entity this function adds the score if it is found by matching
the score.item.itemsId with the pid of the entity
Otherwise the score is filled with -1 (arbitrarily chosen)

:param entities: List of entities that have been retrieved from one ICAT query.
:type entities: :class:`list`
:param scores: List of items retrieved from the scoring application
:type scores: :class:`list`
"""
for entity in entities:
entity["score"] = -1
items = list(
filter(lambda score: str(score["itemId"]) == str(entity["pid"]), scores),
)
if len(items) == 1:
entity["score"] = items[0]["score"]
return entities


@client_manager
def get_search(entity_name, filters):
def get_search(entity_name, filters, str_conditions=None):
"""
Search for data on the given entity, using filters from the request to restrict the
query
Expand All @@ -84,6 +154,8 @@ def get_search(entity_name, filters):
:type entity_name: :class:`str`
:param filters: The list of Search API filters to be applied to the request/query
:type filters: List of specific implementation :class:`QueryFilter`
:param str_conditions: Where clause to be applied to the JPQL query
:type str_conditions: :class:`str`
:return: List of records (in JSON serialisable format) of the given entity for the
query constructed from that and the request's filters
"""
Expand All @@ -96,8 +168,8 @@ def get_search(entity_name, filters):
if isinstance(filter_, SearchAPIIncludeFilter):
entity_relations.extend(filter_.included_filters)

query = SearchAPIQuery(entity_name)

query = SearchAPIQuery(entity_name, str_conditions=str_conditions)
log.debug("Query: %s", query)
filter_handler = FilterOrderHandler()
filter_handler.add_filters(filters)
filter_handler.add_icat_relations_for_panosc_non_related_fields(entity_name)
Expand Down Expand Up @@ -198,8 +270,11 @@ def get_files(entity_name, pid, filters):
"""

log.info("Getting files of dataset (PID: %s), using request's filters", pid)
log.debug(
"Entity Name: %s, Filters: %s", entity_name, filters,
log.debug
(
"Entity Name: %s, Filters: %s",
entity_name,
filters,
Comment on lines +273 to +277
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note for me: formatting change within a file that has functionality changes

)

filters.append(SearchAPIWhereFilter("dataset.pid", pid, "eq"))
Expand Down Expand Up @@ -229,3 +304,18 @@ def get_files_count(entity_name, filters, pid):

filters.append(SearchAPIWhereFilter("dataset.pid", pid, "eq"))
return get_count(entity_name, filters)


def get_search_api_query_filter_list(filters):
"""
Returns the list of SearchAPIQueryFilter that are in the filters array
"""
return list(filter(lambda x: isinstance(x, SearchAPIScoringFilter), filters))


@client_manager
def not_query_filter(filters):
"""
Checks if there is a SearchAPIQueryFilter in the list of filters
"""
return len(get_search_api_query_filter_list(filters)) == 1
10 changes: 5 additions & 5 deletions datagateway_api/src/search_api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ class Dataset(PaNOSCAttribute):

@validator("pid", pre=True, always=True)
def set_pid(cls, value): # noqa: B902, N805
return f"pid:{value}" if isinstance(value, int) else value
return f"{value}"

@root_validator(pre=True)
def set_is_public(cls, values): # noqa: B902, N805
Expand Down Expand Up @@ -250,7 +250,7 @@ class Document(PaNOSCAttribute):

@validator("pid", pre=True, always=True)
def set_pid(cls, value): # noqa: B902, N805
return f"pid:{value}" if isinstance(value, int) else value
return value

@root_validator(pre=True)
def set_is_public(cls, values): # noqa: B902, N805
Expand Down Expand Up @@ -296,7 +296,7 @@ class Instrument(PaNOSCAttribute):

@validator("pid", pre=True, always=True)
def set_pid(cls, value): # noqa: B902, N805
return f"pid:{value}" if isinstance(value, int) else value
return value

@classmethod
def from_icat(cls, icat_data, required_related_fields):
Expand Down Expand Up @@ -399,7 +399,7 @@ class Sample(PaNOSCAttribute):

@validator("pid", pre=True, always=True)
def set_pid(cls, value): # noqa: B902, N805
return f"pid:{value}" if isinstance(value, int) else value
return value

@classmethod
def from_icat(cls, icat_data, required_related_fields):
Expand All @@ -419,7 +419,7 @@ class Technique(PaNOSCAttribute):

@validator("pid", pre=True, always=True)
def set_pid(cls, value): # noqa: B902, N805
return f"pid:{value}" if isinstance(value, int) else value
return value

@classmethod
def from_icat(cls, icat_data, required_related_fields):
Expand Down
9 changes: 4 additions & 5 deletions datagateway_api/src/search_api/panosc_mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,13 @@ def get_icat_mapping(self, panosc_entity_name, field_name):
:raises FilterError: If a valid mapping cannot be found
"""

log.info(
"Searching mapping file to find ICAT translation for %s",
f"{panosc_entity_name}.{field_name}",
)
# log.debug(
# "Searching mapping file to find ICAT translation for %s",
# f"{panosc_entity_name}.{field_name}",
# )

try:
icat_mapping = self.mappings[panosc_entity_name][field_name]
log.debug("ICAT mapping/translation found: %s", icat_mapping)
except KeyError as e:
raise FilterError(f"Bad PaNOSC to ICAT mapping: {e.args}")

Expand Down
1 change: 0 additions & 1 deletion datagateway_api/src/search_api/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ def __init__(self, panosc_entity_name, **kwargs):
self.icat_entity_name = mappings.mappings[panosc_entity_name][
"base_icat_entity"
]

self.icat_query = SearchAPIICATQuery(
SessionHandler.client, self.icat_entity_name, **kwargs,
)
Expand Down
Loading