Skip to content

Commit

Permalink
Mock requests.Session.get in TestClient (#150)
Browse files Browse the repository at this point in the history
  • Loading branch information
lukasschwab authored Oct 25, 2023
1 parent ece3022 commit 8ef0759
Showing 1 changed file with 35 additions and 33 deletions.
68 changes: 35 additions & 33 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
import arxiv
from datetime import datetime, timedelta
from pytest import approx
from requests import Response

def empty_response(code: int) -> Response:
r = Response()
r.status_code = code
r._content = b''
return r

class TestClient(unittest.TestCase):
def test_invalid_format_id(self):
Expand Down Expand Up @@ -90,10 +96,10 @@ def test_no_duplicates(self):
self.assertFalse(r.entry_id in ids)
ids.add(r.entry_id)

@patch('requests.Session.get', return_value=empty_response(500))
@patch("time.sleep", return_value=None)
def test_retry(self, patched_time_sleep):
broken_client = TestClient.get_code_client(500)

def test_retry(self, mock_sleep, mock_get):
broken_client = arxiv.Client()
def broken_get():
search = arxiv.Search(query="quantum")
return next(broken_client.results(search))
Expand All @@ -109,77 +115,73 @@ def broken_get():
self.assertEqual(e.status, 500)
self.assertEqual(e.retry, broken_client.num_retries)

@patch('requests.Session.get', return_value=empty_response(200))
@patch("time.sleep", return_value=None)
def test_sleep_standard(self, patched_time_sleep):
client = TestClient.get_code_client(200)
def test_sleep_standard(self, mock_sleep, mock_get):
client = arxiv.Client()
url = client._format_url(arxiv.Search(query="quantum"), 0, 1)
# A client should sleep until delay_seconds have passed.
client._parse_feed(url)
patched_time_sleep.assert_not_called()
mock_sleep.assert_not_called()
# Overwrite _last_request_dt to minimize flakiness: different
# environments will have different page fetch times.
client._last_request_dt = datetime.now()
client._parse_feed(url)
patched_time_sleep.assert_called_once_with(approx(client.delay_seconds, rel=1e-3))
mock_sleep.assert_called_once_with(approx(client.delay_seconds, rel=1e-3))

@patch('requests.Session.get', return_value=empty_response(200))
@patch("time.sleep", return_value=None)
def test_sleep_multiple_requests(self, patched_time_sleep):
client = TestClient.get_code_client(200)
def test_sleep_multiple_requests(self, mock_sleep, mock_get):
client = arxiv.Client()
url1 = client._format_url(arxiv.Search(query="quantum"), 0, 1)
url2 = client._format_url(arxiv.Search(query="testing"), 0, 1)
# Rate limiting is URL-independent; expect same behavior as in
# `test_sleep_standard`.
client._parse_feed(url1)
patched_time_sleep.assert_not_called()
mock_sleep.assert_not_called()
client._last_request_dt = datetime.now()
client._parse_feed(url2)
patched_time_sleep.assert_called_once_with(approx(client.delay_seconds, rel=1e-3))
mock_sleep.assert_called_once_with(approx(client.delay_seconds, rel=1e-3))

@patch('requests.Session.get', return_value=empty_response(200))
@patch("time.sleep", return_value=None)
def test_sleep_elapsed(self, patched_time_sleep):
client = TestClient.get_code_client(200)
def test_sleep_elapsed(self, mock_sleep, mock_get):
client = arxiv.Client()
url = client._format_url(arxiv.Search(query="quantum"), 0, 1)
# If _last_request_dt is less than delay_seconds ago, sleep.
client._last_request_dt = datetime.now() - timedelta(seconds=client.delay_seconds - 1)
client._parse_feed(url)
patched_time_sleep.assert_called_once()
patched_time_sleep.reset_mock()
mock_sleep.assert_called_once()
mock_sleep.reset_mock()
# If _last_request_dt is at least delay_seconds ago, don't sleep.
client._last_request_dt = datetime.now() - timedelta(seconds=client.delay_seconds)
client._parse_feed(url)
patched_time_sleep.assert_not_called()
mock_sleep.assert_not_called()

@patch('requests.Session.get', return_value=empty_response(200))
@patch("time.sleep", return_value=None)
def test_sleep_zero_delay(self, patched_time_sleep):
client = TestClient.get_code_client(code=200, delay_seconds=0)
def test_sleep_zero_delay(self, mock_sleep, mock_get):
client = arxiv.Client(delay_seconds=0)
url = client._format_url(arxiv.Search(query="quantum"), 0, 1)
client._parse_feed(url)
client._parse_feed(url)
patched_time_sleep.assert_not_called()
mock_sleep.assert_not_called()

@patch('requests.Session.get', return_value=empty_response(500))
@patch("time.sleep", return_value=None)
def test_sleep_between_errors(self, patched_time_sleep):
client = TestClient.get_code_client(500)
def test_sleep_between_errors(self, mock_sleep, mock_get):
client = arxiv.Client()
url = client._format_url(arxiv.Search(query="quantum"), 0, 1)
try:
client._parse_feed(url)
except arxiv.HTTPError:
pass
# Should sleep between retries.
patched_time_sleep.assert_called()
self.assertEqual(patched_time_sleep.call_count, client.num_retries)
patched_time_sleep.assert_has_calls(
mock_sleep.assert_called()
self.assertEqual(mock_sleep.call_count, client.num_retries)
mock_sleep.assert_has_calls(
[
call(approx(client.delay_seconds, abs=1e-2)),
]
* client.num_retries
)

def get_code_client(code: int, delay_seconds=0.1, num_retries=3) -> arxiv.Client:
"""
get_code_client returns an arxiv.Cient with HTTP requests routed to
httpstat.us.
"""
client = arxiv.Client(delay_seconds=delay_seconds, num_retries=num_retries)
client.query_url_format = "https://teapot.fly.dev/{}?".format(code) + "{}"
return client

0 comments on commit 8ef0759

Please sign in to comment.