Skip to content

Commit

Permalink
Testing: Add type annotations to baseclient; rucio#6588
Browse files Browse the repository at this point in the history
  • Loading branch information
rdimaio committed Nov 15, 2024
1 parent 3f8b419 commit 64bfe0d
Showing 1 changed file with 44 additions and 20 deletions.
64 changes: 44 additions & 20 deletions lib/rucio/client/baseclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from os import environ, fdopen, geteuid, makedirs, path
from shutil import move
from tempfile import mkstemp
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Optional, Union
from urllib.parse import urlparse

import requests
Expand All @@ -43,9 +43,11 @@
from rucio.common.utils import build_url, get_tmp_dir, my_key_generator, parse_response, setup_logger, ssh_sign

if TYPE_CHECKING:
from collections.abc import Generator
from collections.abc import Iterator, Sequence
from logging import Logger

from requests.structures import CaseInsensitiveDict

EXTRA_MODULES = import_extras(['requests_kerberos'])

if EXTRA_MODULES['requests_kerberos']:
Expand All @@ -63,11 +65,11 @@


@REGION.cache_on_arguments(namespace='host_to_choose')
def choice(hosts):
def choice(hosts: "Sequence[str]") -> str:
"""
Select randomly a host
Randomly select a host
:param hosts: Lost of hosts
:param hosts: List of hosts
:return: A randomly selected host.
"""
return secrets.choice(hosts)
Expand All @@ -89,8 +91,8 @@ def __init__(self,
ca_cert: Optional[str] = None,
auth_type: Optional[str] = None,
creds: Optional[dict[str, Any]] = None,
timeout: Optional[int] = 600,
user_agent: Optional[str] = 'rucio-clients',
timeout: int = 600,
user_agent: str = 'rucio-clients',
vo: Optional[str] = None,
logger: 'Logger' = LOG) -> None:
"""
Expand Down Expand Up @@ -332,7 +334,12 @@ def _get_creds(self, creds: Optional[dict[str, Any]]) -> dict[str, Any]:

return creds

def _get_exception(self, headers: dict[str, str], status_code: Optional[int] = None, data=None) -> tuple[type[exception.RucioException], str]:
def _get_exception(
self,
headers: "CaseInsensitiveDict",
status_code: Optional[int] = None,
data: Optional[Union[str, bytes, bytearray]] = None
) -> tuple[type[exception.RucioException], str]:
"""
Helper method to parse an error string send by the server and transform it into the corresponding rucio exception.
Expand All @@ -344,20 +351,20 @@ def _get_exception(self, headers: dict[str, str], status_code: Optional[int] = N
"""
if data is not None:
try:
data = parse_response(data)
parsed_data = parse_response(data)
except ValueError:
data = {}
parsed_data = {}
else:
data = {}
parsed_data = {}

exc_cls = 'RucioException'
exc_msg = 'no error information passed (http status code: %s)' % status_code
if 'ExceptionClass' in data:
exc_cls = data['ExceptionClass']
if 'ExceptionClass' in parsed_data:
exc_cls = parsed_data['ExceptionClass']
elif 'ExceptionClass' in headers:
exc_cls = headers['ExceptionClass']
if 'ExceptionMessage' in data:
exc_msg = data['ExceptionMessage']
if 'ExceptionMessage' in parsed_data:
exc_msg = parsed_data['ExceptionMessage']
elif 'ExceptionMessage' in headers:
exc_msg = headers['ExceptionMessage']

Expand All @@ -366,7 +373,7 @@ def _get_exception(self, headers: dict[str, str], status_code: Optional[int] = N
else:
return exception.RucioException, "%s: %s" % (exc_cls, exc_msg)

def _load_json_data(self, response: requests.Response) -> 'Generator[Any, Any, Any]':
def _load_json_data(self, response: requests.Response) -> "Iterator[Any]":
"""
Helper method to correctly load json data based on the content type of the http response.
Expand All @@ -382,8 +389,15 @@ def _load_json_data(self, response: requests.Response) -> 'Generator[Any, Any, A
if response.text:
yield response.text

def _reduce_data(self, data, maxlen: int = 132) -> str:
text = data if isinstance(data, str) else data.decode("utf-8")
def _reduce_data(
self,
data: Union[str, bytes, bytearray],
maxlen: int = 132
) -> str:
if isinstance(data, str):
text = data
if isinstance(data, (bytes, bytearray)):
text = data.decode("utf-8")
if len(text) > maxlen:
text = "%s ... %s" % (text[:maxlen - 15], text[-10:])
return text
Expand All @@ -398,8 +412,18 @@ def _back_off(self, retry_number: int, reason: str) -> None:
self.logger.warning("Waiting {}s due to reason: {} ".format(sleep_time, reason))
time.sleep(sleep_time)

def _send_request(self, url, headers=None, type_='GET', data=None, params=None, stream=False, get_token=False,
cert=None, auth=None, verify=None):
def _send_request(
self,
url: Union[str, bytes],
headers: Optional[dict[str, str]] = None,
type_: str = 'GET',
data: Optional[Union[str, dict[str, Any]]] = None,
params: Optional[dict[str, Any]] = None,
stream: bool = False,
get_token: bool = False,
cert: Optional[Union[str, tuple[str, str]]] = None,
auth: Optional[Any] = None,
verify: Optional[Union[bool, str]] = None):
"""
Helper method to send requests to the rucio server. Gets a new token and retries if an unauthorized error is returned.
Expand Down

0 comments on commit 64bfe0d

Please sign in to comment.