Skip to content

Commit

Permalink
MAST query result cache: Observations.query_criteria()
Browse files Browse the repository at this point in the history
  • Loading branch information
orionlee authored and ceb8 committed Sep 27, 2022
1 parent 62a34f0 commit d1a8f66
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 12 deletions.
55 changes: 46 additions & 9 deletions astroquery/mast/discovery_portal.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import uuid
import json
import time
import re

import numpy as np

Expand All @@ -18,7 +19,7 @@
from astropy.table import Table, vstack, MaskedColumn
from astropy.utils import deprecated

from ..query import BaseQuery, QueryWithLogin
from ..query import BaseQuery, QueryWithLogin, AstroQuery, to_cache
from ..utils import async_to_sync
from ..utils.class_or_instance import class_or_instance
from ..exceptions import InputWarning, NoResultsWarning, RemoteServiceError
Expand Down Expand Up @@ -211,7 +212,39 @@ def _request(self, method, url, params=None, data=None, headers=None,

return all_responses

def _get_col_config(self, service, fetch_name=None):
def _request_w_cache(self, method, url, data=None, headers=None, retrieve_all=True,
cache=False, cache_opts=None):
# Note: the method only exposes 4 parameters of the underlying _request() function
# to play nice with existing mocks
# Caching: follow BaseQuery._request()'s pattern, which uses an AstroQuery object
if not cache:
response = self._request(method, url, data=data, headers=headers, retrieve_all=retrieve_all)
else:
cacher = self._get_cacher(method, url, data, headers, retrieve_all)
response = cacher.from_cache(self.cache_location)
if not response:
response = self._request(method, url, data=data, headers=headers, retrieve_all=retrieve_all)
to_cache(response, cacher.request_file(self.cache_location))
return response

def _get_cacher(self, method, url, data, headers, retrieve_all):
"""
Return an object that can cache the HTTP request based on the supplied arguments
"""

# cacheBreaker parameter (to underlying MAST service) is not relevant (and breaks) local caching
# remove it from part of the cache key
data_no_cache_breaker = re.sub(r'^(.+)cacheBreaker%22%3A%20%22.+%22', r'\1', data)
# include retrieve_all as part of the cache key by appending it to data
# it cannot be added as part of req_kwargs dict, as it will be rejected by AstroQuery
data_w_retrieve_all = data_no_cache_breaker + " retrieve_all={}".format(retrieve_all)
req_kwargs = dict(
data=data_no_cache_breaker,
headers=headers
)
return AstroQuery(method, url, **req_kwargs)

def _get_col_config(self, service, fetch_name=None, cache=False):
"""
Gets the columnsConfig entry for given service and stores it in `self._column_configs`.
Expand Down Expand Up @@ -247,7 +280,7 @@ def _get_col_config(self, service, fetch_name=None):
if more:
mashup_request = {'service': all_name, 'params': {}, 'format': 'extjs'}
req_string = _prepare_service_request_string(mashup_request)
response = self._request("POST", self.MAST_REQUEST_URL, data=req_string, headers=headers)
response = self._request_w_cache("POST", self.MAST_REQUEST_URL, data=req_string, headers=headers, cache=cache)
json_response = response[0].json()

self._column_configs[service].update(json_response['data']['Tables'][0]
Expand Down Expand Up @@ -301,7 +334,7 @@ def _parse_result(self, responses, verbose=False):
return all_results

@class_or_instance
def service_request_async(self, service, params, pagesize=None, page=None, **kwargs):
def service_request_async(self, service, params, pagesize=None, page=None, cache=False, cache_opts=None, **kwargs):
"""
Given a Mashup service and parameters, builds and excecutes a Mashup query.
See documentation `here <https://mast.stsci.edu/api/v0/class_mashup_1_1_mashup_request.html>`__
Expand All @@ -321,6 +354,10 @@ def service_request_async(self, service, params, pagesize=None, page=None, **kwa
Default None.
Can be used to override the default behavior of all results being returned to obtain
a specific page of results.
cache : Boolean, optional
try to use cached the query result if set to True
cache_opts : dict, optional
cache options, details TBD, e.g., cache expiration policy, etc.
**kwargs :
See MashupRequest properties
`here <https://mast.stsci.edu/api/v0/class_mashup_1_1_mashup_request.html>`__
Expand All @@ -334,7 +371,7 @@ def service_request_async(self, service, params, pagesize=None, page=None, **kwa
# setting self._current_service
if service not in self._column_configs.keys():
fetch_name = kwargs.pop('fetch_name', None)
self._get_col_config(service, fetch_name)
self._get_col_config(service, fetch_name, cache)
self._current_service = service

# setting up pagination
Expand All @@ -360,12 +397,12 @@ def service_request_async(self, service, params, pagesize=None, page=None, **kwa
mashup_request[prop] = value

req_string = _prepare_service_request_string(mashup_request)
response = self._request("POST", self.MAST_REQUEST_URL, data=req_string, headers=headers,
retrieve_all=retrieve_all)
response = self._request_w_cache("POST", self.MAST_REQUEST_URL, data=req_string, headers=headers,
retrieve_all=retrieve_all, cache=cache, cache_opts=cache_opts)

return response

def build_filter_set(self, column_config_name, service_name=None, **filters):
def build_filter_set(self, column_config_name, service_name=None, cache=False, **filters):
"""
Takes user input dictionary of filters and returns a filterlist that the Mashup can understand.
Expand Down Expand Up @@ -393,7 +430,7 @@ def build_filter_set(self, column_config_name, service_name=None, **filters):
service_name = column_config_name

if not self._column_configs.get(service_name):
self._get_col_config(service_name, fetch_name=column_config_name)
self._get_col_config(service_name, fetch_name=column_config_name, cache=cache)

caom_col_config = self._column_configs[service_name]

Expand Down
4 changes: 2 additions & 2 deletions astroquery/mast/observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def query_object_async(self, objectname, *, radius=0.2*u.deg, pagesize=None, pag
return self.query_region_async(coordinates, radius=radius, pagesize=pagesize, page=page)

@class_or_instance
def query_criteria_async(self, *, pagesize=None, page=None, **criteria):
def query_criteria_async(self, *, pagesize=None, page=None, cache=False, cache_opts=None, **criteria):
"""
Given an set of criteria, returns a list of MAST observations.
Valid criteria are returned by ``get_metadata("observations")``
Expand Down Expand Up @@ -300,7 +300,7 @@ def query_criteria_async(self, *, pagesize=None, page=None, **criteria):
params = {"columns": "*",
"filters": mashup_filters}

return self._portal_api_connection.service_request_async(service, params)
return self._portal_api_connection.service_request_async(service, params, cache=cache, cache_opts=cache_opts)

def query_region_count(self, coordinates, *, radius=0.2*u.deg, pagesize=None, page=None):
"""
Expand Down
50 changes: 50 additions & 0 deletions astroquery/mast/tests/test_mast.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,25 @@ def patch_post(request):
return mp


_num_mockreturn = 0


def _get_num_mockreturn():
global _num_mockreturn
return _num_mockreturn


def _reset_mockreturn_counter():
global _num_mockreturn
_num_mockreturn = 0


def _inc_num_mockreturn():
global _num_mockreturn
_num_mockreturn += 1
return _num_mockreturn


def post_mockreturn(self, method="POST", url=None, data=None, timeout=10, **kwargs):
if "columnsconfig" in url:
if "Mast.Catalogs.Tess.Cone" in data:
Expand All @@ -102,6 +121,9 @@ def post_mockreturn(self, method="POST", url=None, data=None, timeout=10, **kwar
with open(filename, 'rb') as infile:
content = infile.read()

# For cache tests
_inc_num_mockreturn()

# returning as list because this is what the mast _request function does
return [MockResponse(content)]

Expand Down Expand Up @@ -367,6 +389,34 @@ def test_query_observations_criteria_async(patch_post):
assert isinstance(responses, list)


def test_query_observations_criteria_async_cache(patch_post):
_reset_mockreturn_counter()
assert 0 == _get_num_mockreturn(), "Mock HTTP call counter reset to 0"

responses_cache_miss = mast.Observations.query_criteria_async(dataproduct_type=["image"],
proposal_pi="Ost*",
s_dec=[43.5, 45.5], cache=True)
assert isinstance(responses_cache_miss, list)
num_mockreturn_after_first_call = _get_num_mockreturn()
assert num_mockreturn_after_first_call > 0, "Cache miss, some underlying HTTP call"

responses_cache_hit = mast.Observations.query_criteria_async(dataproduct_type=["image"],
proposal_pi="Ost*",
s_dec=[43.5, 45.5], cache=True)
# assert the cached response is the same
assert len(responses_cache_hit) == len(responses_cache_miss)
assert responses_cache_hit[0].text == responses_cache_miss[0].text
# ensure the response really comes from the cache
assert num_mockreturn_after_first_call == _get_num_mockreturn(), \
'Cache hit: should reach cache only, i.e., no HTTP call'

responses_no_cache = mast.Observations.query_criteria_async(dataproduct_type=["image"],
proposal_pi="Ost*",
s_dec=[43.5, 45.5], cache=False)
assert isinstance(responses_no_cache, list)
assert _get_num_mockreturn() > num_mockreturn_after_first_call, "Cache off , some underlying HTTP call"


def test_observations_query_criteria(patch_post):
# without position
result = mast.Observations.query_criteria(dataproduct_type=["image"],
Expand Down
3 changes: 2 additions & 1 deletion astroquery/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ def from_cache(self, cache_location):
try:
with open(request_file, "rb") as f:
response = pickle.load(f)
if not isinstance(response, requests.Response):
if not isinstance(response, requests.Response) and not isinstance(response, list):
# MAST query response is a list of Response
response = None
except FileNotFoundError:
response = None
Expand Down

0 comments on commit d1a8f66

Please sign in to comment.