Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provider-specific HTTP parameters #63

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 23 additions & 12 deletions geospaas_processing/downloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def get_auth(cls, kwargs):
return (None, None)

@classmethod
def connect(cls, url, auth=(None, None)):
def connect(cls, url, auth=(None, None), **kwargs):
"""Connect to the remote repository. This should return an
object from which the file to download can be read.
"""
Expand All @@ -83,7 +83,7 @@ def close_connection(cls, connection):
connection.close()

@classmethod
def get_file_name(cls, url, auth):
def get_file_name(cls, url, auth, **kwargs):
"""Returns the name of the file"""
raise NotImplementedError()

Expand All @@ -108,14 +108,14 @@ def check_and_download_url(cls, url, download_dir, **kwargs):
"""
auth = cls.get_auth(kwargs)

file_name = cls.get_file_name(url, auth)
file_name = cls.get_file_name(url, auth, **kwargs)
if not file_name:
raise DownloadError(f"Could not find file name for '{url}'")
file_path = os.path.join(download_dir, file_name)
if os.path.exists(file_path) and os.path.isfile(file_path):
return file_name, False

connection = cls.connect(url, auth)
connection = cls.connect(url, auth, **kwargs)
try:
file_size = cls.get_file_size(url, connection)
if file_size:
Expand Down Expand Up @@ -174,12 +174,22 @@ def get_auth(cls, kwargs):
return super().get_auth(kwargs)

@classmethod
def get_file_name(cls, url, auth):
def get_request_parameters(cls, kwargs):
parameters = kwargs.get('request_parameters', {})
if isinstance(parameters, dict):
return parameters
else:
raise ValueError(
"The 'request_parameters' configuration key should contain a dictionary")

@classmethod
def get_file_name(cls, url, auth, **kwargs):
"""Extracts the file name from the Content-Disposition header
of an HTTP response
"""
try:
response = utils.http_request('HEAD', url, auth=auth)
response = utils.http_request(
'HEAD', url, auth=auth, params=cls.get_request_parameters(kwargs))
response.raise_for_status()
except requests.RequestException:
try:
Expand Down Expand Up @@ -213,13 +223,14 @@ def get_file_name(cls, url, auth):
return ''

@classmethod
def connect(cls, url, auth=(None, None)):
def connect(cls, url, auth=(None, None), **kwargs):
"""For HTTP downloads, the "connection" actually just consists
of sending a GET request to the download URL and return the
corresponding Response object
"""
try:
response = utils.http_request('GET', url, stream=True, auth=auth)
response = utils.http_request(
'GET', url, stream=True, auth=auth, params=cls.get_request_parameters(kwargs))
response.raise_for_status()
# Raising DownloadError enables to display a clear message in the API response
except requests.HTTPError as error:
Expand Down Expand Up @@ -272,7 +283,7 @@ class FTPDownloader(Downloader):
"""Downloader for FTP repositories"""

@classmethod
def connect(cls, url, auth=(None, None)):
def connect(cls, url, auth=(None, None), **kwargs):
"""Connects to the remote FTP repository.
Returns a ftplib.FTP object.
"""
Expand All @@ -284,7 +295,7 @@ def connect(cls, url, auth=(None, None)):
raise DownloadError(f"Could not download from '{url}': {error.args}") from error

@classmethod
def get_file_name(cls, url, auth):
def get_file_name(cls, url, auth, **kwargs):
"""Extracts the file name from the URL"""
return urlparse(url).path.split('/')[-1] or None

Expand Down Expand Up @@ -312,15 +323,15 @@ def get_auth(kwargs):
return (None, None)

@classmethod
def connect(cls, url, auth=(None, None)):
def connect(cls, url, auth=(None, None), **kwargs):
return None

@classmethod
def close_connection(cls, connection):
return None

@classmethod
def get_file_name(cls, url, auth):
def get_file_name(cls, url, auth, **kwargs):
return os.path.basename(url)

@classmethod
Expand Down
3 changes: 3 additions & 0 deletions geospaas_processing/provider_settings.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,7 @@
'https://podaac-tools.jpl.nasa.gov/drive/files':
username: !ENV 'PODAAC_DRIVE_USERNAME'
password: !ENV 'PODAAC_DRIVE_PASSWORD'
'https://oceandata.sci.gsfc.nasa.gov':
request_parameters:
appkey: !ENV EARTHDATA_APPKEY
...
69 changes: 67 additions & 2 deletions tests/test_downloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ class TestDownloader(downloaders.Downloader):
"""

@classmethod
def connect(cls, url, auth=(None, None)):
def connect(cls, url, auth=(None, None), **kwargs):
return mock.Mock()

@classmethod
def get_file_name(cls, url, connection):
def get_file_name(cls, url, auth, **kwargs):
return 'test_file.txt'

@classmethod
Expand Down Expand Up @@ -223,6 +223,27 @@ def test_get_basic_auth(self):
('username', 'password')
)

def test_get_request_parameters(self):
"""get_request_parameters() should return the
'request_parameters' key from kwargs if it is present and
contains a dictionary
"""
self.assertDictEqual(
downloaders.HTTPDownloader.get_request_parameters({
'foo': 'bar',
'request_parameters': {'baz': 'qux'}
}),
{'baz': 'qux'})

self.assertDictEqual(downloaders.HTTPDownloader.get_request_parameters({'foo': 'bar'}), {})

def test_get_request_parameters_invalid(self):
"""get_request_parameters() should raise an exception if the
'request_parameters' key's contents are invalid
"""
with self.assertRaises(ValueError):
downloaders.HTTPDownloader.get_request_parameters({'request_parameters': 'foo'})

def test_get_file_name(self):
"""Test the correct extraction of a file name from a standard
Content-Disposition header
Expand Down Expand Up @@ -288,6 +309,36 @@ def test_get_file_name_head_error(self):
with self.assertLogs(downloaders.LOGGER, level=logging.ERROR):
self.assertEqual(downloaders.HTTPDownloader.get_file_name('url', None), '')

def test_get_file_name_with_parameters(self):
"""Test getting a file name with parameters in the HEAD request
"""
file_name = "test_file.txt"
response = requests.Response()
response.status_code = 200
response.headers['Content-Disposition'] = f'inline;filename="{file_name}"'
with mock.patch(
'geospaas_processing.utils.http_request',
return_value=response) as mock_http_request:
self.assertEqual(
downloaders.HTTPDownloader.get_file_name(
'url', None, request_parameters={'foo': 'bar'}),
file_name)
mock_http_request.assert_called_once_with('HEAD', 'url', auth=None, params={'foo': 'bar'})

def test_get_file_name_without_parameters(self):
"""Test getting a file name without parameters in the HEAD
request
"""
file_name = "test_file.txt"
response = requests.Response()
response.status_code = 200
response.headers['Content-Disposition'] = f'inline;filename="{file_name}"'
with mock.patch(
'geospaas_processing.utils.http_request',
return_value=response) as mock_http_request:
self.assertEqual(downloaders.HTTPDownloader.get_file_name('url', None), file_name)
mock_http_request.assert_called_once_with('HEAD', 'url', auth=None, params={})

def test_connect(self):
"""Connect should return a Response object"""
response = requests.Response()
Expand All @@ -296,6 +347,20 @@ def test_connect(self):
connect_result = downloaders.HTTPDownloader.connect('url')
self.assertEqual(connect_result, response)

def test_connect_with_parameters(self):
"""Test connecting with parameters to the GET request"""
with mock.patch('geospaas_processing.utils.http_request') as mock_http_request:
downloaders.HTTPDownloader.connect('url', request_parameters={'appkey': 'foo'})
mock_http_request.assert_called_once_with(
'GET', 'url', stream=True, auth=(None, None), params={'appkey': 'foo'})

def test_connect_without_parameters(self):
"""Test connecting with parameters to the GET request"""
with mock.patch('geospaas_processing.utils.http_request') as mock_http_request:
downloaders.HTTPDownloader.connect('url')
mock_http_request.assert_called_once_with(
'GET', 'url', stream=True, auth=(None, None), params={})

def test_connect_error_code(self):
"""An exception should be raised when an error code is received
"""
Expand Down