Skip to content

Commit

Permalink
Merge branch 'tk9'
Browse files Browse the repository at this point in the history
  • Loading branch information
fabiobatalha committed Oct 10, 2017
2 parents 187a8f5 + e13e1ae commit 32323dc
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 46 deletions.
2 changes: 1 addition & 1 deletion crossref/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
VERSION = '1.1.1'
VERSION = '1.2.0'
126 changes: 82 additions & 44 deletions crossref/restful.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@

import requests
import json
from time import sleep

from datetime import datetime, timedelta

from crossref import validators, VERSION

LIMIT = 100
MAXOFFSET = 10000
FACETS_MAX_LIMIT = 1000

API = "api.crossref.org"


Expand All @@ -23,27 +27,61 @@ class UrlSyntaxError(CrossrefAPIError, ValueError):
pass


class HTTPRequest(object):
    """Issue HTTP requests, optionally pausing after each one.

    When throttling is enabled, the ``X-Rate-Limit-Limit`` and
    ``X-Rate-Limit-Interval`` headers of every GET/POST response are stored
    in ``rate_limits`` and the instance sleeps for ``throttling_time``
    seconds before handing the response back to the caller.
    """

    # NOTE(review): not referenced by any method of this class; the name
    # (including its spelling) is left untouched because it is a public
    # class attribute.
    THROTTLING_TUNNING_TIME = 600

    def __init__(self, throttle=True):
        """
        :param throttle: ``True`` to sleep after each GET/POST request
            according to the rate limits reported by the server.
        """
        self.throttle = throttle
        # Defaults used until a response reports its own limits:
        # 50 requests per 1-second interval.
        self.rate_limits = {
            'X-Rate-Limit-Limit': 50,
            'X-Rate-Limit-Interval': 1
        }

    def _update_rate_limits(self, headers):
        """Refresh ``rate_limits`` from a mapping of response headers.

        The interval header is read as ``<integer><unit>`` where the unit is
        the last character: ``m`` is converted to seconds (x60), ``h`` to
        seconds (x3600), and any other unit leaves the number unchanged.
        Missing headers fall back to ``50`` and ``'1s'``.

        :param headers: mapping exposing ``get`` (e.g. response headers).
        :raises ValueError: if a header value cannot be parsed as an integer.
        """
        self.rate_limits['X-Rate-Limit-Limit'] = int(headers.get('X-Rate-Limit-Limit', 50))

        interval = headers.get('X-Rate-Limit-Interval', '1s')
        interval_value = int(interval[:-1])
        interval_scope = interval[-1]

        if interval_scope == 'm':
            interval_value = interval_value * 60

        if interval_scope == 'h':
            interval_value = interval_value * 60 * 60

        self.rate_limits['X-Rate-Limit-Interval'] = interval_value

    @property
    def throttling_time(self):
        """Seconds to wait between requests: interval divided by limit."""
        # float() forces true division; with two ints Python 2 would floor
        # 1 / 50 to 0 and throttling would silently become a no-op.
        return float(self.rate_limits['X-Rate-Limit-Interval']) / self.rate_limits['X-Rate-Limit-Limit']

    def do_http_request(self, method, endpoint, data=None, files=None, timeout=100, only_headers=False, custom_header=None, throttle=None):
        """Send a request to ``endpoint`` and return the response object.

        :param method: ``'post'`` issues a POST; any other value issues a GET.
        :param endpoint: URL to request.
        :param data: form data for POST, query parameters for GET.
        :param files: files forwarded to the POST request only.
        :param timeout: timeout in seconds passed to the underlying request.
        :param only_headers: when ``True`` a HEAD request is issued and
            returned immediately, without throttling.
        :param custom_header: user-agent string; when falsy, a default
            ``Etiquette()`` is rendered and used instead.
        :param throttle: per-call override of the instance ``throttle``
            setting; ``None`` (default) keeps the instance setting.
        :return: the response returned by ``requests``.
        """
        if custom_header:
            headers = {'user-agent': custom_header}
        else:
            headers = {'user-agent': str(Etiquette())}

        if only_headers is True:
            # Forward timeout and user-agent so a HEAD request cannot hang
            # indefinitely and still identifies the client.
            return requests.head(endpoint, timeout=timeout, headers=headers)

        if method == 'post':
            result = requests.post(endpoint, data=data, files=files, timeout=timeout, headers=headers)
        else:
            result = requests.get(endpoint, params=data, timeout=timeout, headers=headers)

        should_throttle = self.throttle if throttle is None else throttle

        if should_throttle is True:
            self._update_rate_limits(result.headers)
            sleep(self.throttling_time)

        return result


def build_url_endpoint(endpoint, context=None):
Expand All @@ -56,7 +94,6 @@ def build_url_endpoint(endpoint, context=None):
class Etiquette:

def __init__(self, application_name='undefined', application_version='undefined', application_url='undefined', contact_email='anonymous'):

self.application_name = application_name
self.application_version = application_version
self.application_url = application_url
Expand All @@ -77,8 +114,8 @@ class Endpoint:

CURSOR_AS_ITER_METHOD = False

def __init__(self, request_url=None, request_params=None, context=None, etiquette=None):

def __init__(self, request_url=None, request_params=None, context=None, etiquette=None, throttle=True):
self.do_http_request = HTTPRequest(throttle=throttle).do_http_request
self.etiquette = etiquette or Etiquette()
self.request_url = request_url or build_url_endpoint(self.ENDPOINT, context)
self.request_params = request_params or dict()
Expand All @@ -89,11 +126,12 @@ def _rate_limits(self):
request_params = dict(self.request_params)
request_url = str(self.request_url)

result = do_http_request(
result = self.do_http_request(
'get',
request_url,
only_headers=True,
custom_header=str(self.etiquette)
custom_header=str(self.etiquette),
throttle=False
)

rate_limits = {
Expand Down Expand Up @@ -126,7 +164,7 @@ def version(self):
request_params = dict(self.request_params)
request_url = str(self.request_url)

result = do_http_request(
result = self.do_http_request(
'get',
request_url,
data=request_params,
Expand Down Expand Up @@ -167,7 +205,7 @@ def count(self):
request_url = str(self.request_url)
request_params['rows'] = 0

result = do_http_request(
result = self.do_http_request(
'get',
request_url,
data=request_params,
Expand Down Expand Up @@ -216,7 +254,7 @@ def __iter__(self):

if 'sample' in self.request_params:
request_params = self._escaped_pagging()
result = do_http_request(
result = self.do_http_request(
'get',
self.request_url,
data=request_params,
Expand All @@ -238,7 +276,7 @@ def __iter__(self):
request_params['cursor'] = '*'
request_params['rows'] = LIMIT
while True:
result = do_http_request(
result = self.do_http_request(
'get',
request_url,
data=request_params,
Expand All @@ -262,7 +300,7 @@ def __iter__(self):
request_params['offset'] = 0
request_params['rows'] = LIMIT
while True:
result = do_http_request(
result = self.do_http_request(
'get',
request_url,
data=request_params,
Expand Down Expand Up @@ -763,7 +801,7 @@ def facet(self, facet_name, facet_count=100):
facet_count = self.FACET_VALUES[facet_name] if self.FACET_VALUES[facet_name] is not None and self.FACET_VALUES[facet_name] <= facet_count else facet_count

request_params['facet'] = '%s:%s' % (facet_name, facet_count)
result = do_http_request(
result = self.do_http_request(
'get',
request_url,
data=request_params,
Expand Down Expand Up @@ -906,7 +944,7 @@ def doi(self, doi, only_message=True):
)
request_params = {}

result = do_http_request(
result = self.do_http_request(
'get',
request_url,
data=request_params,
Expand Down Expand Up @@ -940,7 +978,7 @@ def agency(self, doi, only_message=True):
)
request_params = {}

result = do_http_request(
result = self.do_http_request(
'get',
request_url,
data=request_params,
Expand Down Expand Up @@ -980,7 +1018,7 @@ def doi_exists(self, doi):
)
request_params = {}

result = do_http_request(
result = self.do_http_request(
'get',
request_url,
data=request_params,
Expand Down Expand Up @@ -1105,7 +1143,7 @@ def funder(self, funder_id, only_message=True):
)
request_params = {}

result = do_http_request(
result = self.do_http_request(
'get',
request_url,
data=request_params,
Expand Down Expand Up @@ -1145,7 +1183,7 @@ def funder_exists(self, funder_id):
)
request_params = {}

result = do_http_request(
result = self.do_http_request(
'get',
request_url,
data=request_params,
Expand Down Expand Up @@ -1324,7 +1362,7 @@ def member(self, member_id, only_message=True):
)
request_params = {}

result = do_http_request(
result = self.do_http_request(
'get',
request_url,
data=request_params,
Expand Down Expand Up @@ -1364,7 +1402,7 @@ def member_exists(self, member_id):
)
request_params = {}

result = do_http_request(
result = self.do_http_request(
'get',
request_url,
data=request_params,
Expand Down Expand Up @@ -1411,7 +1449,7 @@ def type(self, type_id, only_message=True):
)
request_params = {}

result = do_http_request(
result = self.do_http_request(
'get',
request_url,
data=request_params,
Expand Down Expand Up @@ -1447,7 +1485,7 @@ def all(self):
request_url = build_url_endpoint(self.ENDPOINT, self.context)
request_params = dict(self.request_params)

result = do_http_request(
result = self.do_http_request(
'get',
request_url,
data=request_params,
Expand Down Expand Up @@ -1488,7 +1526,7 @@ def type_exists(self, type_id):
)
request_params = {}

result = do_http_request(
result = self.do_http_request(
'get',
request_url,
data=request_params,
Expand Down Expand Up @@ -1539,7 +1577,7 @@ def prefix(self, prefix_id, only_message=True):
)
request_params = {}

result = do_http_request(
result = self.do_http_request(
'get',
request_url,
data=request_params,
Expand Down Expand Up @@ -1623,7 +1661,7 @@ def journal(self, issn, only_message=True):
)
request_params = {}

result = do_http_request(
result = self.do_http_request(
'get',
request_url,
data=request_params,
Expand Down Expand Up @@ -1664,7 +1702,7 @@ def journal_exists(self, issn):
)
request_params = {}

result = do_http_request(
result = self.do_http_request(
'get',
request_url,
data=request_params,
Expand Down Expand Up @@ -1693,7 +1731,7 @@ def works(self, issn):
class Depositor(object):

def __init__(self, prefix, api_user, api_key, etiquette=None):

self.do_http_request = HTTPRequest(throttle=False).do_http_request
self.etiquette = etiquette or Etiquette()
self.prefix = prefix
self.api_user = api_user
Expand Down Expand Up @@ -1723,7 +1761,7 @@ def register_doi(self, submission_id, request_xml):
'login_passwd': self.api_key
}

result = do_http_request(
result = self.do_http_request(
'post',
endpoint,
data=params,
Expand Down Expand Up @@ -1754,7 +1792,7 @@ def request_doi_status_by_filename(self, file_name, data_type='result'):
'type': data_type
}

result = do_http_request(
result = self.do_http_request(
'get',
endpoint,
data=params,
Expand Down Expand Up @@ -1784,7 +1822,7 @@ def request_doi_status_by_batch_id(self, doi_batch_id, data_type='result'):
'type': data_type
}

result = do_http_request(
result = self.do_http_request(
'get',
endpoint,
data=params,
Expand Down
45 changes: 44 additions & 1 deletion tests/test_restful.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,47 @@ def test_members_filters(self):
def test_funders_filters(self):
result = restful.Funders(etiquette=self.etiquette).filter(location="Japan").url

self.assertEqual(result, 'https://api.crossref.org/funders?filter=location%3AJapan')
self.assertEqual(result, 'https://api.crossref.org/funders?filter=location%3AJapan')


class HTTPRequestTest(unittest.TestCase):
    """Checks on the rate-limit bookkeeping of ``restful.HTTPRequest``."""

    def setUp(self):
        # A fresh instance per test keeps header updates from leaking
        # between cases.
        self.httprequest = restful.HTTPRequest()

    def test_default_rate_limits(self):
        """A new instance reports 50 requests per 1-second interval."""
        self.assertEqual(
            self.httprequest.rate_limits,
            {'X-Rate-Limit-Interval': 1, 'X-Rate-Limit-Limit': 50}
        )

    def test_update_rate_limits_seconds(self):
        """An interval given in seconds is stored unchanged."""
        self.httprequest._update_rate_limits(
            {'X-Rate-Limit-Interval': '2s', 'X-Rate-Limit-Limit': 50}
        )

        self.assertEqual(
            self.httprequest.rate_limits,
            {'X-Rate-Limit-Interval': 2, 'X-Rate-Limit-Limit': 50}
        )

    def test_update_rate_limits_minutes(self):
        """An interval given in minutes is stored as seconds."""
        self.httprequest._update_rate_limits(
            {'X-Rate-Limit-Interval': '2m', 'X-Rate-Limit-Limit': 50}
        )

        self.assertEqual(
            self.httprequest.rate_limits,
            {'X-Rate-Limit-Interval': 120, 'X-Rate-Limit-Limit': 50}
        )

    def test_update_rate_limits_hours(self):
        """An interval given in hours is stored as seconds."""
        self.httprequest._update_rate_limits(
            {'X-Rate-Limit-Interval': '2h', 'X-Rate-Limit-Limit': 50}
        )

        self.assertEqual(
            self.httprequest.rate_limits,
            {'X-Rate-Limit-Interval': 7200, 'X-Rate-Limit-Limit': 50}
        )

0 comments on commit 32323dc

Please sign in to comment.