diff --git a/paperqa/clients/semantic_scholar.py b/paperqa/clients/semantic_scholar.py
index d6a74171a..53cd2e9d9 100644
--- a/paperqa/clients/semantic_scholar.py
+++ b/paperqa/clients/semantic_scholar.py
@@ -110,6 +110,23 @@ def make_url_params(  # noqa: PLR0911
     raise NotImplementedError
 
 
+@retry(
+    retry=retry_if_exception(make_flaky_ssl_error_predicate(SEMANTIC_SCHOLAR_HOST)),
+    before_sleep=before_sleep_log(logger, logging.WARNING),
+    stop=stop_after_attempt(3),
+)
+async def _s2_get_with_retrying(url: str, **get_kwargs) -> dict[str, Any]:
+    return await _get_with_retrying(
+        url=url,
+        headers=get_kwargs.get("headers") or semantic_scholar_headers(),
+        timeout=(
+            get_kwargs.get("timeout")
+            or aiohttp.ClientTimeout(SEMANTIC_SCHOLAR_API_REQUEST_TIMEOUT)
+        ),
+        **get_kwargs,
+    )
+
+
 def s2_authors_match(authors: list[str], data: dict) -> bool:
     """Check if the authors in the data match the authors in the paper."""
     AUTHOR_NAME_MIN_LENGTH = 2
@@ -131,7 +148,7 @@ def s2_authors_match(authors: list[str], data: dict) -> bool:
 
 
 async def parse_s2_to_doc_details(
-    paper_data: dict, session: aiohttp.ClientSession
+    paper_data: dict[str, Any], session: aiohttp.ClientSession
 ) -> DocDetails:
 
     bibtex_source = "self_generated"
@@ -217,12 +234,10 @@ async def s2_title_search(
         params={"query": title, "fields": fields}
     )
 
-    data = await _get_with_retrying(
+    data = await _s2_get_with_retrying(
         url=endpoint,
         params=params,
         session=session,
-        headers=semantic_scholar_headers(),
-        timeout=SEMANTIC_SCHOLAR_API_REQUEST_TIMEOUT,
         http_exception_mappings={
             HTTPStatus.NOT_FOUND: DOINotFoundError(f"Could not find DOI for {title}.")
         },
@@ -277,19 +292,18 @@ async def get_s2_doc_details_from_doi(
     else:
         s2_fields = SEMANTIC_SCHOLAR_API_FIELDS
 
-    details = await _get_with_retrying(
-        url=f"{SEMANTIC_SCHOLAR_BASE_URL}/graph/v1/paper/DOI:{doi}",
-        params={"fields": s2_fields},
+    return await parse_s2_to_doc_details(
+        paper_data=await _s2_get_with_retrying(
+            url=f"{SEMANTIC_SCHOLAR_BASE_URL}/graph/v1/paper/DOI:{doi}",
+            params={"fields": s2_fields},
+            session=session,
+            http_exception_mappings={
+                HTTPStatus.NOT_FOUND: DOINotFoundError(f"Could not find DOI for {doi}.")
+            },
+        ),
         session=session,
-        headers=semantic_scholar_headers(),
-        timeout=SEMANTIC_SCHOLAR_API_REQUEST_TIMEOUT,
-        http_exception_mappings={
-            HTTPStatus.NOT_FOUND: DOINotFoundError(f"Could not find DOI for {doi}.")
-        },
     )
-    return await parse_s2_to_doc_details(details, session)
-
 
 
 async def get_s2_doc_details_from_title(
     title: str | None,
diff --git a/paperqa/utils.py b/paperqa/utils.py
index fbcc48338..7e738c20b 100644
--- a/paperqa/utils.py
+++ b/paperqa/utils.py
@@ -409,20 +409,13 @@ def is_retryable(exc: BaseException) -> bool:
 )
 async def _get_with_retrying(
     url: str,
-    params: dict[str, Any],
     session: aiohttp.ClientSession,
-    headers: dict[str, str] | None = None,
-    timeout: float = 10.0,  # noqa: ASYNC109
     http_exception_mappings: dict[HTTPStatus | int, Exception] | None = None,
+    **get_kwargs,
 ) -> dict[str, Any]:
     """Get from a URL with retrying protection."""
     try:
-        async with session.get(
-            url,
-            params=params,
-            headers=headers,
-            timeout=aiohttp.ClientTimeout(timeout),
-        ) as response:
+        async with session.get(url, **get_kwargs) as response:
             response.raise_for_status()
             return await response.json()
     except aiohttp.ClientResponseError as e: