diff --git a/darwin/backend_v2.py b/darwin/backend_v2.py
index 535b2d498..615aff00f 100644
--- a/darwin/backend_v2.py
+++ b/darwin/backend_v2.py
@@ -1,6 +1,10 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from urllib import parse
 
+from requests.exceptions import HTTPError
+from requests.models import Response
+from tenacity import RetryCallState, retry, stop_after_attempt, wait_exponential_jitter
+
 from darwin.datatypes import ItemId
 
 
@@ -17,6 +21,25 @@ def wrapper(self, *args, **kwargs) -> Callable:
     return wrapper
 
 
+def log_rate_limit_exceeded(retry_state: RetryCallState) -> None:
+    """Report the back-off delay; used as tenacity's ``before_sleep`` hook."""
+    wait_time = retry_state.next_action.sleep
+    print(f"Rate limit exceeded. Retrying in {wait_time:.2f} seconds...")
+
+
+def retry_if_status_code_429(retry_state: RetryCallState) -> bool:
+    """Return True only if the last attempt raised an HTTPError with status 429."""
+    outcome = retry_state.outcome
+    if outcome is None:
+        return False
+    exception = outcome.exception()
+    if not isinstance(exception, HTTPError):
+        return False
+    # HTTPError may carry no response (e.g. raised manually); never retry then.
+    response: Optional[Response] = exception.response
+    return response is not None and response.status_code == 429
+
+
 class BackendV2:
     def __init__(self, client: "Client", default_team):  # noqa F821
         self._client = client
@@ -238,6 +261,12 @@ def import_annotation(
             f"v2/teams/{team_slug}/items/{item_id}/import", payload=payload
         )
 
+    @retry(
+        wait=wait_exponential_jitter(initial=60, max=300),
+        stop=stop_after_attempt(10),
+        retry=retry_if_status_code_429,
+        before_sleep=log_rate_limit_exceeded,
+    )
     @inject_default_team_slug
     def register_items(self, payload: Dict[str, Any], team_slug: str) -> None:
         """
diff --git a/poetry.lock b/poetry.lock
index dfba57e5d..2e8431cff 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
 
 [[package]]
 name = "albucore"
@@ -1863,6 +1863,21 @@ files = [
 [package.dependencies]
 mpmath = ">=0.19"
 
+[[package]]
+name = "tenacity"
+version = "8.3.0"
+description = "Retry code until it succeeds"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "tenacity-8.3.0-py3-none-any.whl", hash = "sha256:3649f6443dbc0d9b01b9d8020a9c4ec7a1ff5f6f3c6c8a036ef371f573fe9185"},
+    {file = "tenacity-8.3.0.tar.gz", hash = "sha256:953d4e6ad24357bceffbc9707bc74349aca9d245f68eb65419cf0c249a1949a2"},
+]
+
+[package.extras]
+doc = ["reno", "sphinx"]
+test = ["pytest", "tornado (>=4.5)", "typeguard"]
+
 [[package]]
 name = "threadpoolctl"
 version = "3.5.0"
@@ -2163,4 +2178,4 @@ test = ["pytest", "responses"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.0,<3.12"
-content-hash = "e2acd10369632e81157fc873add2de947edd824f1f9ed2253b7d9598b3f965d5"
+content-hash = "0d694891abf5df0e8e04d0a17e6ca5a3028c9a04fe079e36fe2920ae9f8a5dc6"
diff --git a/pyproject.toml b/pyproject.toml
index 78ea935c7..fbcfe1647 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -81,6 +81,7 @@ tqdm = "^4.64.1"
 types-pyyaml = "^6.0.12.9"
 types-requests = "^2.28.11.8"
 upolygon = "0.1.11"
+tenacity = "^8.3.0"
 
 [tool.poetry.extras]
 dev = ["black", "isort", "flake8", "mypy", "debugpy", "responses", "pytest", "flake8-pyproject", "pytest-rerunfailures", "ruff", "validate-pyproject"]
diff --git a/tests/darwin/backend_v2_test.py b/tests/darwin/backend_v2_test.py
new file mode 100644
index 000000000..dc5ee6c00
--- /dev/null
+++ b/tests/darwin/backend_v2_test.py
@@ -0,0 +1,28 @@
+from unittest.mock import Mock, call, patch
+
+import pytest
+from requests.exceptions import HTTPError
+from requests.models import Response
+from tenacity import RetryError
+
+from darwin.backend_v2 import BackendV2
+
+
+class TestBackendV2:
+    @patch("time.sleep", return_value=None)
+    def test_register_items_retries_on_429(self, mock_sleep):
+        mock_client = Mock()
+        mock_response = Mock(spec=Response)
+        mock_response.status_code = 429
+        mock_client._post_raw.side_effect = HTTPError(response=mock_response)
+
+        backend = BackendV2(mock_client, "team_slug")
+
+        payload = {"key": "value"}
+        with pytest.raises(RetryError):
+            backend.register_items(payload)
+
+        assert mock_client._post_raw.call_count == 10
+
+        expected_call = call("/v2/teams/team_slug/items/register_existing", payload)
+        assert mock_client._post_raw.call_args_list == [expected_call] * 10
diff --git a/tests/darwin/dataset/remote_dataset_test.py b/tests/darwin/dataset/remote_dataset_test.py
index 1261715aa..9ffaa06ed 100644
--- a/tests/darwin/dataset/remote_dataset_test.py
+++ b/tests/darwin/dataset/remote_dataset_test.py
@@ -1065,7 +1065,7 @@ def test_register_files_with_blocked_items(self, remote_dataset: RemoteDatasetV2
             },
             status=200,
         )
-        remote_dataset.register(
+        result = remote_dataset.register(
             ObjectStore(
                 name="test",
                 prefix="test_prefix",
@@ -1076,3 +1076,5 @@ def test_register_files_with_blocked_items(self, remote_dataset: RemoteDatasetV2
             {"item1": ["test.jpg"]},
             multi_slotted=True,
         )
+        assert len(result["registered"]) == 0
+        assert len(result["blocked"]) == 1