Skip to content

Commit

Permalink
Testing: Add overloads for get_rse_attribute; rucio#6588
Browse files Browse the repository at this point in the history
  • Loading branch information
rdimaio committed Oct 25, 2024
1 parent 431fb7f commit 5ff566f
Show file tree
Hide file tree
Showing 13 changed files with 134 additions and 46 deletions.
8 changes: 4 additions & 4 deletions lib/rucio/client/rseclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator

from rucio.common.constants import RSE_SUPPORTED_PROTOCOL_DOMAINS_LITERAL, RSE_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL, SUPPORTED_PROTOCOLS_LITERAL
from rucio.common.constants import RSE_ALL_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL, RSE_SUPPORTED_PROTOCOL_DOMAINS_LITERAL, SUPPORTED_PROTOCOLS_LITERAL


class RSEClient(BaseClient):
Expand Down Expand Up @@ -238,7 +238,7 @@ def get_protocols(
self,
rse: str,
protocol_domain: "RSE_SUPPORTED_PROTOCOL_DOMAINS_LITERAL" = 'ALL',
operation: Optional["RSE_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL"] = None,
operation: Optional["RSE_ALL_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL"] = None,
default: bool = False,
scheme: Optional['SUPPORTED_PROTOCOLS_LITERAL'] = None
) -> Any:
Expand Down Expand Up @@ -287,7 +287,7 @@ def lfns2pfns(
rse: str,
lfns: 'Iterable[str]',
protocol_domain: 'RSE_SUPPORTED_PROTOCOL_DOMAINS_LITERAL' = 'ALL',
operation: Optional['RSE_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL'] = None,
operation: Optional['RSE_ALL_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL'] = None,
scheme: Optional['SUPPORTED_PROTOCOLS_LITERAL'] = None
) -> dict[str, str]:
"""
Expand Down Expand Up @@ -409,7 +409,7 @@ def swap_protocols(
self,
rse: str,
domain: 'RSE_SUPPORTED_PROTOCOL_DOMAINS_LITERAL',
operation: 'RSE_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL',
operation: 'RSE_ALL_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL',
scheme_a: 'SUPPORTED_PROTOCOLS_LITERAL',
scheme_b: 'SUPPORTED_PROTOCOLS_LITERAL'
) -> bool:
Expand Down
60 changes: 58 additions & 2 deletions lib/rucio/common/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,11 @@

RSE_SUPPORTED_PROTOCOL_DOMAINS_LITERAL = Literal['ALL', 'LAN', 'WAN']

RSE_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL = Literal['read', 'write', 'delete', 'third_party_copy_read', 'third_party_copy_write']
RSE_SUPPORTED_PROTOCOL_OPERATIONS: list[str] = list(get_args(RSE_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL))
RSE_BASE_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL = Literal['read', 'write', 'delete']
RSE_BASE_SUPPORTED_PROTOCOL_OPERATIONS: list[str] = list(get_args(RSE_BASE_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL))

RSE_ALL_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL = Literal[RSE_BASE_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL, 'third_party_copy_read', 'third_party_copy_write']
RSE_ALL_SUPPORTED_PROTOCOL_OPERATIONS: list[str] = list(get_args(RSE_ALL_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL))

FTS_STATE = namedtuple('FTS_STATE', ['SUBMITTED', 'READY', 'ACTIVE', 'FAILED', 'FINISHED', 'FINISHEDDIRTY', 'NOT_USED',
'CANCELED'])('SUBMITTED', 'READY', 'ACTIVE', 'FAILED', 'FINISHED', 'FINISHEDDIRTY',
Expand Down Expand Up @@ -159,3 +162,56 @@ class RseAttr:
DEFAULT_LIMIT_FILES = 'default_limit_files'
QUOTA_APPROVERS = 'quota_approvers'
RULE_DELETERS = 'rule_deleters'


# Literal types to allow overloading of functions with RSE attributes in their signature.
# RSE attributes are encoded via the BooleanString decorator as VARCHAR in the database,
# but they are used as either bool or string in the code.
# This is only determined at runtime, so for static type checking
# we need to manually specify which attrs are string and which are bool.
# In future, we could refactor RseAttr to avoid code duplication.
RSE_ATTRS_STR = Literal[
'archive_timeout',
'associated_sites',
'bittorrent_tracker_addr',
'country',
'decommission',
'default_account_limit_bytes',
'fts',
'globus_endpoint_id',
'lfn2pfn_algorithm',
'maximum_pin_lifetime',
'multihop_tombstone_delay',
'naming_convention',
'oidc_base_path',
'oidc_support'
'physgroup',
'qbittorrent_management_address'
'rule_approvers',
's3_url_style',
'simulate_multirange',
'site',
'source_for_total_space',
'source_for_used_space',
'staging_buffer',
'tombstone_delay',
'type'
]

RSE_ATTRS_BOOL = Literal[
'auto_approve_bytes',
'auto_approve_files',
'block_manual_approval',
'greedyDeletion',
'is_object_store',
'restricted_read',
'restricted_write',
'skip_upload_stat',
'staging_required',
'strict_copy',
'use_ipv4',
'verify_checksum'
]

SUPPORTED_SIGN_URL_SERVICES_LITERAL = Literal['gcs', 's3', 'swift']
SUPPORTED_SIGN_URL_SERVICES = list(get_args(SUPPORTED_SIGN_URL_SERVICES_LITERAL))
23 changes: 11 additions & 12 deletions lib/rucio/core/credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import hmac
import time
from hashlib import sha1
from typing import Literal, Optional
from typing import Optional
from urllib.parse import urlencode, urlparse

import boto3
Expand All @@ -27,7 +27,7 @@

from rucio.common.cache import MemcacheRegion
from rucio.common.config import config_get, get_rse_credentials
from rucio.common.constants import RseAttr
from rucio.common.constants import RSE_BASE_SUPPORTED_PROTOCOL_OPERATIONS, RSE_BASE_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL, SUPPORTED_SIGN_URL_SERVICES, SUPPORTED_SIGN_URL_SERVICES_LITERAL, RseAttr
from rucio.common.exception import UnsupportedOperation
from rucio.core.monitor import MetricManager
from rucio.core.rse import get_rse_attribute
Expand All @@ -40,8 +40,8 @@

def get_signed_url(
rse_id: str,
service: Literal['gsc', 's3', 'swift'],
operation: Literal['read', 'write', 'delete'],
service: SUPPORTED_SIGN_URL_SERVICES_LITERAL,
operation: RSE_BASE_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL,
url: str,
lifetime: Optional[int] = 600
) -> str:
Expand All @@ -60,10 +60,10 @@ def get_signed_url(

global CREDS_GCS

if service not in ['gcs', 's3', 'swift']:
if service not in SUPPORTED_SIGN_URL_SERVICES:
raise UnsupportedOperation('Service must be "gcs", "s3" or "swift"')

if operation not in ['read', 'write', 'delete']:
if operation not in RSE_BASE_SUPPORTED_PROTOCOL_OPERATIONS:
raise UnsupportedOperation('Operation must be "read", "write", or "delete"')

if url is None or url == '':
Expand All @@ -84,10 +84,6 @@ def get_signed_url(
components = urlparse(url)
host = components.netloc

# select the correct operation
operations = {'read': 'GET', 'write': 'PUT', 'delete': 'DELETE'}
operation = operations[operation]

# special case to test signature, force epoch time
if lifetime is None:
lifetime = 0
Expand All @@ -100,8 +96,11 @@ def get_signed_url(
# sign the path only
path = components.path

# Map operations
operations = {'read': 'GET', 'write': 'PUT', 'delete': 'DELETE'}

# assemble message to sign
to_sign = "%s\n\n\n%s\n%s" % (operation, lifetime, path)
to_sign = "%s\n\n\n%s\n%s" % (operations[operation], lifetime, path)

# create URL-capable signature
# first character is always a '=', remove it
Expand Down Expand Up @@ -213,7 +212,7 @@ def get_signed_url(
else:
swiftop = 'DELETE'

expires = int(time.time() + lifetime)
expires = int(time.time() + lifetime) # type: ignore (lifetime could be None)

# create signed URL
with METRICS.timer('signswift'):
Expand Down
44 changes: 37 additions & 7 deletions lib/rucio/core/rse.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from datetime import datetime
from io import StringIO
from re import match
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union
from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, TypeVar, Union, overload

import sqlalchemy
from dogpile.cache.api import NO_VALUE
Expand All @@ -28,7 +28,7 @@
from rucio.common import exception, types, utils
from rucio.common.cache import MemcacheRegion
from rucio.common.config import get_lfn2pfn_algorithm_default
from rucio.common.constants import RSE_SUPPORTED_PROTOCOL_OPERATIONS, RseAttr
from rucio.common.constants import RSE_ALL_SUPPORTED_PROTOCOL_OPERATIONS, RSE_ATTRS_BOOL, RSE_ATTRS_STR, SUPPORTED_SIGN_URL_SERVICES_LITERAL, RseAttr
from rucio.common.utils import CHECKSUM_KEY, GLOBALLY_SUPPORTED_CHECKSUMS, Availability
from rucio.core.rse_counter import add_counter, get_counter
from rucio.db.sqla import models
Expand Down Expand Up @@ -886,7 +886,7 @@ def has_rse_attribute(rse_id, key, *, session: "Session"):


@read_session
def get_rses_with_attribute(key, *, session: "Session"):
def get_rses_with_attribute(key, *, session: "Session") -> list[dict[str, Any]]:
"""
Return all RSEs with a certain attribute.
Expand Down Expand Up @@ -959,6 +959,36 @@ def get_rses_with_attribute_value(key, value, vo='def', *, session: "Session"):
return result


@overload
def get_rse_attribute(rse_id: str, key: Literal['sign_url'], use_cache: bool = True, *, session: "Session") -> Optional[SUPPORTED_SIGN_URL_SERVICES_LITERAL]:
...


@overload
def get_rse_attribute(rse_id: str, key: Literal['sign_url'], use_cache: bool = True) -> Optional[SUPPORTED_SIGN_URL_SERVICES_LITERAL]:
...


@overload
def get_rse_attribute(rse_id: str, key: 'RSE_ATTRS_STR', use_cache: bool = True) -> Optional[str]:
...


@overload
def get_rse_attribute(rse_id: str, key: 'RSE_ATTRS_STR', use_cache: bool = True, *, session: "Session") -> Optional[str]:
...


@overload
def get_rse_attribute(rse_id: str, key: 'RSE_ATTRS_BOOL', use_cache: bool = True) -> Optional[bool]:
...


@overload
def get_rse_attribute(rse_id: str, key: 'RSE_ATTRS_BOOL', use_cache: bool = True, *, session: "Session") -> Optional[bool]:
...


@read_session
def get_rse_attribute(rse_id: str, key: str, use_cache: bool = True, *, session: "Session") -> Optional[Union[str, bool]]:
"""
Expand Down Expand Up @@ -1290,7 +1320,7 @@ def add_protocol(
if domain not in utils.rse_supported_protocol_domains():
raise exception.RSEProtocolDomainNotSupported(f"The protocol domain '{domain}' is not defined in the schema.")
for op in parameter['domains'][domain]:
if op not in RSE_SUPPORTED_PROTOCOL_OPERATIONS:
if op not in RSE_ALL_SUPPORTED_PROTOCOL_OPERATIONS:
raise exception.RSEOperationNotSupported(f"Operation '{op}' not defined in schema.")
op_name = op if op.startswith('third_party_copy') else f'{op}_{domain}'.lower()
priority = parameter['domains'][domain][op]
Expand Down Expand Up @@ -1414,7 +1444,7 @@ def _format_get_rse_protocols(
'verify_checksum': verify_checksum if verify_checksum is not None else True,
'volatile': _rse['volatile']}

for op in RSE_SUPPORTED_PROTOCOL_OPERATIONS:
for op in RSE_ALL_SUPPORTED_PROTOCOL_OPERATIONS:
info['%s_protocol' % op] = 1 # 1 indicates the default protocol

for row in db_protocols:
Expand Down Expand Up @@ -1501,7 +1531,7 @@ def update_protocols(
if domain not in utils.rse_supported_protocol_domains():
raise exception.RSEProtocolDomainNotSupported(f"The protocol domain '{domain}' is not defined in the schema.")
for op in data['domains'][domain]:
if op not in RSE_SUPPORTED_PROTOCOL_OPERATIONS:
if op not in RSE_ALL_SUPPORTED_PROTOCOL_OPERATIONS:
raise exception.RSEOperationNotSupported(f"Operation '{op}' not defined in schema.")
op_name = op if op.startswith('third_party_copy') else f'{op}_{domain}'.lower()
priority = data['domains'][domain][op]
Expand Down Expand Up @@ -1869,7 +1899,7 @@ def determine_scope_for_rse(
# a base which should be removed from the prefix (in order for '/' to
# mean the entire resource associated with that issuer).
prefix = protocol['prefix']
if base_path := get_rse_attribute(rse_id, RseAttr.OIDC_BASE_PATH):
if base_path := get_rse_attribute(rse_id, RseAttr.OIDC_BASE_PATH): # type: ignore (session parameter missing)
prefix = prefix.removeprefix(base_path)
filtered_prefixes.add(prefix)
all_scopes = [f'{s}:{p}' for s in scopes for p in filtered_prefixes] + list(extra_scopes)
Expand Down
10 changes: 5 additions & 5 deletions lib/rucio/core/rse_expression_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import abc
import re
from hashlib import sha256
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from dogpile.cache.api import NoValue

Expand Down Expand Up @@ -286,15 +286,15 @@ def resolve_elements(self, session):
"""
Inherited from :py:func:`BaseExpressionElement.resolve_elements`
"""
rse_list = get_rses_with_attribute(key=self.key, session=session)
rse_list: list[dict[str, Any]] = get_rses_with_attribute(key=self.key, session=session)
if not rse_list:
return (set(), {})

output = []
rse_dict = {}
for rse in rse_list:
try:
if float(get_rse_attribute(rse['id'], self.key, session=session)) < float(self.value):
if float(get_rse_attribute(rse['id'], self.key, session=session)) < float(self.value): # type: ignore (get_rse_attribute could return None)
rse_dict[rse['id']] = rse
output.append(rse['id'])
except ValueError:
Expand All @@ -321,15 +321,15 @@ def resolve_elements(self, session):
"""
Inherited from :py:func:`BaseExpressionElement.resolve_elements`
"""
rse_list = get_rses_with_attribute(key=self.key, session=session)
rse_list: list[dict[str, Any]] = get_rses_with_attribute(key=self.key, session=session)
if not rse_list:
return (set(), {})

output = []
rse_dict = {}
for rse in rse_list:
try:
if float(get_rse_attribute(rse['id'], self.key, session=session)) > float(self.value):
if float(get_rse_attribute(rse['id'], self.key, session=session)) > float(self.value): # type: ignore (get_rse_attribute could return None)
rse_dict[rse['id']] = rse
output.append(rse['id'])
except ValueError:
Expand Down
2 changes: 1 addition & 1 deletion lib/rucio/daemons/bb8/nuclei_background_rebalance.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def group_space(site: str) -> int:
global_ratio = float(0)
for rse in rses:
site_name = get_rse_attribute(rse['id'], RseAttr.SITE)
rse['groupdisk'] = group_space(site_name)
rse['groupdisk'] = group_space(site_name) # type: ignore (site_name could be None)
rse['primary'] = get_rse_usage(rse_id=rse['id'], source='rucio')[0]['used'] - get_rse_usage(rse_id=rse['id'], source='expired')[0]['used']
rse['primary'] += rse['groupdisk']
rse['secondary'] = get_rse_usage(rse_id=rse['id'], source='expired')[0]['used']
Expand Down
2 changes: 1 addition & 1 deletion lib/rucio/daemons/bb8/t2_background_rebalance.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def group_space(site: str) -> int:
global_ratio = float(0)
for rse in rses:
site_name = get_rse_attribute(rse['id'], RseAttr.SITE)
rse['groupdisk'] = group_space(site_name)
rse['groupdisk'] = group_space(site_name) # type: ignore (site_name could be None)
rse['primary'] = get_rse_usage(rse_id=rse['id'], source='rucio')[0]['used'] - get_rse_usage(rse_id=rse['id'], source='expired')[0]['used']
rse['primary'] += rse['groupdisk']
rse['secondary'] = get_rse_usage(rse_id=rse['id'], source='expired')[0]['used']
Expand Down
2 changes: 1 addition & 1 deletion lib/rucio/daemons/rsedecommissioner/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def set_status(
:param rse_id: RSE ID.
:param status: RSE decommissioning status.
"""
config = attr_to_config(get_rse_attribute(rse_id, RseAttr.DECOMMISSION))
config = attr_to_config(get_rse_attribute(rse_id, RseAttr.DECOMMISSION)) # type: ignore (get_rse_attribute could return None)
config['status'] = status
# add_rse_attribute can handle updating existing entries too
add_rse_attribute(rse_id, RseAttr.DECOMMISSION, config_to_attr(config))
2 changes: 1 addition & 1 deletion lib/rucio/daemons/rsedecommissioner/rse_decommissioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def run_once(
# Get the decommission attribute (encodes the decommissioning config)
attr = get_rse_attribute(rse['id'], RseAttr.DECOMMISSION)
try:
config = attr_to_config(attr)
config = attr_to_config(attr) # type: ignore (attr could be None)
except InvalidStatusName:
logger(logging.ERROR, 'RSE %s has an invalid decommissioning status',
rse['rse'])
Expand Down
8 changes: 5 additions & 3 deletions lib/rucio/gateway/credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Literal
from typing import TYPE_CHECKING

from rucio.common import exception
from rucio.core import credential
Expand All @@ -23,15 +23,17 @@
if TYPE_CHECKING:
from sqlalchemy.orm import Session

from rucio.common.constants import RSE_BASE_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL, SUPPORTED_SIGN_URL_SERVICES_LITERAL


@read_session
def get_signed_url(
account: str,
appid: str,
ip: str,
rse: str,
service: Literal['gsc', 's3', 'swift'],
operation: Literal['read', 'write', 'delete'],
service: 'SUPPORTED_SIGN_URL_SERVICES_LITERAL',
operation: 'RSE_BASE_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL',
url: str,
lifetime: int,
vo: str = 'def',
Expand Down
Loading

0 comments on commit 5ff566f

Please sign in to comment.