From 5ff566fec41d9845801950b35183b35d878538fb Mon Sep 17 00:00:00 2001 From: rdimaio Date: Tue, 16 Jul 2024 10:56:01 +0200 Subject: [PATCH] Testing: Add overloads for get_rse_attribute; #6588 --- lib/rucio/client/rseclient.py | 8 +-- lib/rucio/common/constants.py | 60 ++++++++++++++++++- lib/rucio/core/credential.py | 23 ++++--- lib/rucio/core/rse.py | 44 +++++++++++--- lib/rucio/core/rse_expression_parser.py | 10 ++-- .../bb8/nuclei_background_rebalance.py | 2 +- .../daemons/bb8/t2_background_rebalance.py | 2 +- lib/rucio/daemons/rsedecommissioner/config.py | 2 +- .../rsedecommissioner/rse_decommissioner.py | 2 +- lib/rucio/gateway/credential.py | 8 ++- lib/rucio/rse/rsemanager.py | 8 +-- lib/rucio/web/rest/flaskapi/v1/credentials.py | 5 +- tests/test_rse.py | 6 +- 13 files changed, 134 insertions(+), 46 deletions(-) diff --git a/lib/rucio/client/rseclient.py b/lib/rucio/client/rseclient.py index e82b1b6e18..b415e2202d 100644 --- a/lib/rucio/client/rseclient.py +++ b/lib/rucio/client/rseclient.py @@ -24,7 +24,7 @@ if TYPE_CHECKING: from collections.abc import Iterable, Iterator - from rucio.common.constants import RSE_SUPPORTED_PROTOCOL_DOMAINS_LITERAL, RSE_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL, SUPPORTED_PROTOCOLS_LITERAL + from rucio.common.constants import RSE_ALL_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL, RSE_SUPPORTED_PROTOCOL_DOMAINS_LITERAL, SUPPORTED_PROTOCOLS_LITERAL class RSEClient(BaseClient): @@ -238,7 +238,7 @@ def get_protocols( self, rse: str, protocol_domain: "RSE_SUPPORTED_PROTOCOL_DOMAINS_LITERAL" = 'ALL', - operation: Optional["RSE_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL"] = None, + operation: Optional["RSE_ALL_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL"] = None, default: bool = False, scheme: Optional['SUPPORTED_PROTOCOLS_LITERAL'] = None ) -> Any: @@ -287,7 +287,7 @@ def lfns2pfns( rse: str, lfns: 'Iterable[str]', protocol_domain: 'RSE_SUPPORTED_PROTOCOL_DOMAINS_LITERAL' = 'ALL', - operation: Optional['RSE_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL'] = None, + operation: Optional['RSE_ALL_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL'] = None, scheme: Optional['SUPPORTED_PROTOCOLS_LITERAL'] = None ) -> dict[str, str]: """ @@ -409,7 +409,7 @@ def swap_protocols( self, rse: str, domain: 'RSE_SUPPORTED_PROTOCOL_DOMAINS_LITERAL', - operation: 'RSE_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL', + operation: 'RSE_ALL_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL', scheme_a: 'SUPPORTED_PROTOCOLS_LITERAL', scheme_b: 'SUPPORTED_PROTOCOLS_LITERAL' ) -> bool: diff --git a/lib/rucio/common/constants.py b/lib/rucio/common/constants.py index 227f6078d0..66c33e1670 100644 --- a/lib/rucio/common/constants.py +++ b/lib/rucio/common/constants.py @@ -56,8 +56,11 @@ RSE_SUPPORTED_PROTOCOL_DOMAINS_LITERAL = Literal['ALL', 'LAN', 'WAN'] -RSE_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL = Literal['read', 'write', 'delete', 'third_party_copy_read', 'third_party_copy_write'] -RSE_SUPPORTED_PROTOCOL_OPERATIONS: list[str] = list(get_args(RSE_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL)) +RSE_BASE_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL = Literal['read', 'write', 'delete'] +RSE_BASE_SUPPORTED_PROTOCOL_OPERATIONS: list[str] = list(get_args(RSE_BASE_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL)) + +RSE_ALL_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL = Literal[RSE_BASE_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL, 'third_party_copy_read', 'third_party_copy_write'] +RSE_ALL_SUPPORTED_PROTOCOL_OPERATIONS: list[str] = list(get_args(RSE_ALL_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL)) FTS_STATE = namedtuple('FTS_STATE', ['SUBMITTED', 'READY', 'ACTIVE', 'FAILED', 'FINISHED', 'FINISHEDDIRTY', 'NOT_USED', 'CANCELED'])('SUBMITTED', 'READY', 'ACTIVE', 'FAILED', 'FINISHED', 'FINISHEDDIRTY', @@ -159,3 +162,56 @@ class RseAttr: DEFAULT_LIMIT_FILES = 'default_limit_files' QUOTA_APPROVERS = 'quota_approvers' RULE_DELETERS = 'rule_deleters' + + +# Literal types to allow overloading of functions with RSE attributes in their signature. +# RSE attributes are encoded via the BooleanString decorator as VARCHAR in the database, +# but they are used as either bool or string in the code. +# This is only determined at runtime, so for static type checking +# we need to manually specify which attrs are string and which are bool. +# In future, we could refactor RseAttr to avoid code duplication. +RSE_ATTRS_STR = Literal[ + 'archive_timeout', + 'associated_sites', + 'bittorrent_tracker_addr', + 'country', + 'decommission', + 'default_account_limit_bytes', + 'fts', + 'globus_endpoint_id', + 'lfn2pfn_algorithm', + 'maximum_pin_lifetime', + 'multihop_tombstone_delay', + 'naming_convention', + 'oidc_base_path', + 'oidc_support' + 'physgroup', + 'qbittorrent_management_address' + 'rule_approvers', + 's3_url_style', + 'simulate_multirange', + 'site', + 'source_for_total_space', + 'source_for_used_space', + 'staging_buffer', + 'tombstone_delay', + 'type' +] + +RSE_ATTRS_BOOL = Literal[ + 'auto_approve_bytes', + 'auto_approve_files', + 'block_manual_approval', + 'greedyDeletion', + 'is_object_store', + 'restricted_read', + 'restricted_write', + 'skip_upload_stat', + 'staging_required', + 'strict_copy', + 'use_ipv4', + 'verify_checksum' +] + +SUPPORTED_SIGN_URL_SERVICES_LITERAL = Literal['gcs', 's3', 'swift'] +SUPPORTED_SIGN_URL_SERVICES = list(get_args(SUPPORTED_SIGN_URL_SERVICES_LITERAL)) diff --git a/lib/rucio/core/credential.py b/lib/rucio/core/credential.py index b89d654f07..97661f842f 100644 --- a/lib/rucio/core/credential.py +++ b/lib/rucio/core/credential.py @@ -17,7 +17,7 @@ import hmac import time from hashlib import sha1 -from typing import Literal, Optional +from typing import Optional from urllib.parse import urlencode, urlparse import boto3 @@ -27,7 +27,7 @@ from rucio.common.cache import MemcacheRegion from rucio.common.config import config_get, get_rse_credentials -from rucio.common.constants import RseAttr +from rucio.common.constants import RSE_BASE_SUPPORTED_PROTOCOL_OPERATIONS, RSE_BASE_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL, SUPPORTED_SIGN_URL_SERVICES, SUPPORTED_SIGN_URL_SERVICES_LITERAL, RseAttr from rucio.common.exception import UnsupportedOperation from rucio.core.monitor import MetricManager from rucio.core.rse import get_rse_attribute @@ -40,8 +40,8 @@ def get_signed_url( rse_id: str, - service: Literal['gsc', 's3', 'swift'], - operation: Literal['read', 'write', 'delete'], + service: SUPPORTED_SIGN_URL_SERVICES_LITERAL, + operation: RSE_BASE_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL, url: str, lifetime: Optional[int] = 600 ) -> str: @@ -60,10 +60,10 @@ def get_signed_url( global CREDS_GCS - if service not in ['gcs', 's3', 'swift']: + if service not in SUPPORTED_SIGN_URL_SERVICES: raise UnsupportedOperation('Service must be "gcs", "s3" or "swift"') - if operation not in ['read', 'write', 'delete']: + if operation not in RSE_BASE_SUPPORTED_PROTOCOL_OPERATIONS: raise UnsupportedOperation('Operation must be "read", "write", or "delete"') if url is None or url == '': @@ -84,10 +84,6 @@ def get_signed_url( components = urlparse(url) host = components.netloc - # select the correct operation - operations = {'read': 'GET', 'write': 'PUT', 'delete': 'DELETE'} - operation = operations[operation] - # special case to test signature, force epoch time if lifetime is None: lifetime = 0 @@ -100,8 +96,11 @@ def get_signed_url( # sign the path only path = components.path + # Map operations + operations = {'read': 'GET', 'write': 'PUT', 'delete': 'DELETE'} + # assemble message to sign - to_sign = "%s\n\n\n%s\n%s" % (operation, lifetime, path) + to_sign = "%s\n\n\n%s\n%s" % (operations[operation], lifetime, path) # create URL-capable signature # first character is always a '=', remove it @@ -213,7 +212,7 @@ def get_signed_url( else: swiftop = 'DELETE' - expires = int(time.time() + lifetime) + expires = int(time.time() + lifetime) # type: ignore (lifetime could be None) # create signed URL with METRICS.timer('signswift'): diff --git a/lib/rucio/core/rse.py b/lib/rucio/core/rse.py index 62f20e50d7..b1bacf023d 100644 --- a/lib/rucio/core/rse.py +++ b/lib/rucio/core/rse.py @@ -16,7 +16,7 @@ from datetime import datetime from io import StringIO from re import match -from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, TypeVar, Union, overload import sqlalchemy from dogpile.cache.api import NO_VALUE @@ -28,7 +28,7 @@ from rucio.common import exception, types, utils from rucio.common.cache import MemcacheRegion from rucio.common.config import get_lfn2pfn_algorithm_default -from rucio.common.constants import RSE_SUPPORTED_PROTOCOL_OPERATIONS, RseAttr +from rucio.common.constants import RSE_ALL_SUPPORTED_PROTOCOL_OPERATIONS, RSE_ATTRS_BOOL, RSE_ATTRS_STR, SUPPORTED_SIGN_URL_SERVICES_LITERAL, RseAttr from rucio.common.utils import CHECKSUM_KEY, GLOBALLY_SUPPORTED_CHECKSUMS, Availability from rucio.core.rse_counter import add_counter, get_counter from rucio.db.sqla import models @@ -886,7 +886,7 @@ def has_rse_attribute(rse_id, key, *, session: "Session"): @read_session -def get_rses_with_attribute(key, *, session: "Session"): +def get_rses_with_attribute(key, *, session: "Session") -> list[dict[str, Any]]: """ Return all RSEs with a certain attribute. @@ -959,6 +959,36 @@ def get_rses_with_attribute_value(key, value, vo='def', *, session: "Session"): return result +@overload +def get_rse_attribute(rse_id: str, key: Literal['sign_url'], use_cache: bool = True, *, session: "Session") -> Optional[SUPPORTED_SIGN_URL_SERVICES_LITERAL]: + ... + + +@overload +def get_rse_attribute(rse_id: str, key: Literal['sign_url'], use_cache: bool = True) -> Optional[SUPPORTED_SIGN_URL_SERVICES_LITERAL]: + ... + + +@overload +def get_rse_attribute(rse_id: str, key: 'RSE_ATTRS_STR', use_cache: bool = True) -> Optional[str]: + ... + + +@overload +def get_rse_attribute(rse_id: str, key: 'RSE_ATTRS_STR', use_cache: bool = True, *, session: "Session") -> Optional[str]: + ... + + +@overload +def get_rse_attribute(rse_id: str, key: 'RSE_ATTRS_BOOL', use_cache: bool = True) -> Optional[bool]: + ... + + +@overload +def get_rse_attribute(rse_id: str, key: 'RSE_ATTRS_BOOL', use_cache: bool = True, *, session: "Session") -> Optional[bool]: + ... + + @read_session def get_rse_attribute(rse_id: str, key: str, use_cache: bool = True, *, session: "Session") -> Optional[Union[str, bool]]: """ @@ -1290,7 +1320,7 @@ def add_protocol( if domain not in utils.rse_supported_protocol_domains(): raise exception.RSEProtocolDomainNotSupported(f"The protocol domain '{domain}' is not defined in the schema.") for op in parameter['domains'][domain]: - if op not in RSE_SUPPORTED_PROTOCOL_OPERATIONS: + if op not in RSE_ALL_SUPPORTED_PROTOCOL_OPERATIONS: raise exception.RSEOperationNotSupported(f"Operation '{op}' not defined in schema.") op_name = op if op.startswith('third_party_copy') else f'{op}_{domain}'.lower() priority = parameter['domains'][domain][op] @@ -1414,7 +1444,7 @@ def _format_get_rse_protocols( 'verify_checksum': verify_checksum if verify_checksum is not None else True, 'volatile': _rse['volatile']} - for op in RSE_SUPPORTED_PROTOCOL_OPERATIONS: + for op in RSE_ALL_SUPPORTED_PROTOCOL_OPERATIONS: info['%s_protocol' % op] = 1 # 1 indicates the default protocol for row in db_protocols: @@ -1501,7 +1531,7 @@ def update_protocols( if domain not in utils.rse_supported_protocol_domains(): raise exception.RSEProtocolDomainNotSupported(f"The protocol domain '{domain}' is not defined in the schema.") for op in data['domains'][domain]: - if op not in RSE_SUPPORTED_PROTOCOL_OPERATIONS: + if op not in RSE_ALL_SUPPORTED_PROTOCOL_OPERATIONS: raise exception.RSEOperationNotSupported(f"Operation '{op}' not defined in schema.") op_name = op if op.startswith('third_party_copy') else f'{op}_{domain}'.lower() priority = data['domains'][domain][op] @@ -1869,7 +1899,7 @@ def determine_scope_for_rse( # a base which should be removed from the prefix (in order for '/' to # mean the entire resource associated with that issuer). prefix = protocol['prefix'] - if base_path := get_rse_attribute(rse_id, RseAttr.OIDC_BASE_PATH): + if base_path := get_rse_attribute(rse_id, RseAttr.OIDC_BASE_PATH): # type: ignore (session parameter missing) prefix = prefix.removeprefix(base_path) filtered_prefixes.add(prefix) all_scopes = [f'{s}:{p}' for s in scopes for p in filtered_prefixes] + list(extra_scopes) diff --git a/lib/rucio/core/rse_expression_parser.py b/lib/rucio/core/rse_expression_parser.py index 0e740900e0..6efc205f32 100644 --- a/lib/rucio/core/rse_expression_parser.py +++ b/lib/rucio/core/rse_expression_parser.py @@ -15,7 +15,7 @@ import abc import re from hashlib import sha256 -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from dogpile.cache.api import NoValue @@ -286,7 +286,7 @@ def resolve_elements(self, session): """ Inherited from :py:func:`BaseExpressionElement.resolve_elements` """ - rse_list = get_rses_with_attribute(key=self.key, session=session) + rse_list: list[dict[str, Any]] = get_rses_with_attribute(key=self.key, session=session) if not rse_list: return (set(), {}) @@ -294,7 +294,7 @@ def resolve_elements(self, session): rse_dict = {} for rse in rse_list: try: - if float(get_rse_attribute(rse['id'], self.key, session=session)) < float(self.value): + if float(get_rse_attribute(rse['id'], self.key, session=session)) < float(self.value): # type: ignore (get_rse_attribute could return None) rse_dict[rse['id']] = rse output.append(rse['id']) except ValueError: @@ -321,7 +321,7 @@ def resolve_elements(self, session): """ Inherited from :py:func:`BaseExpressionElement.resolve_elements` """ - rse_list = get_rses_with_attribute(key=self.key, session=session) + rse_list: list[dict[str, Any]] = get_rses_with_attribute(key=self.key, session=session) if not rse_list: return (set(), {}) @@ -329,7 +329,7 @@ def resolve_elements(self, session): rse_dict = {} for rse in rse_list: try: - if float(get_rse_attribute(rse['id'], self.key, session=session)) > float(self.value): + if float(get_rse_attribute(rse['id'], self.key, session=session)) > float(self.value): # type: ignore (get_rse_attribute could return None) rse_dict[rse['id']] = rse output.append(rse['id']) except ValueError: diff --git a/lib/rucio/daemons/bb8/nuclei_background_rebalance.py b/lib/rucio/daemons/bb8/nuclei_background_rebalance.py index 08b7b9b73f..855f0b2032 100644 --- a/lib/rucio/daemons/bb8/nuclei_background_rebalance.py +++ b/lib/rucio/daemons/bb8/nuclei_background_rebalance.py @@ -61,7 +61,7 @@ def group_space(site: str) -> int: global_ratio = float(0) for rse in rses: site_name = get_rse_attribute(rse['id'], RseAttr.SITE) - rse['groupdisk'] = group_space(site_name) + rse['groupdisk'] = group_space(site_name) # type: ignore (site_name could be None) rse['primary'] = get_rse_usage(rse_id=rse['id'], source='rucio')[0]['used'] - get_rse_usage(rse_id=rse['id'], source='expired')[0]['used'] rse['primary'] += rse['groupdisk'] rse['secondary'] = get_rse_usage(rse_id=rse['id'], source='expired')[0]['used'] diff --git a/lib/rucio/daemons/bb8/t2_background_rebalance.py b/lib/rucio/daemons/bb8/t2_background_rebalance.py index 428affabe6..ac4f4e8635 100644 --- a/lib/rucio/daemons/bb8/t2_background_rebalance.py +++ b/lib/rucio/daemons/bb8/t2_background_rebalance.py @@ -61,7 +61,7 @@ def group_space(site: str) -> int: global_ratio = float(0) for rse in rses: site_name = get_rse_attribute(rse['id'], RseAttr.SITE) - rse['groupdisk'] = group_space(site_name) + rse['groupdisk'] = group_space(site_name) # type: ignore (site_name could be None) rse['primary'] = get_rse_usage(rse_id=rse['id'], source='rucio')[0]['used'] - get_rse_usage(rse_id=rse['id'], source='expired')[0]['used'] rse['primary'] += rse['groupdisk'] rse['secondary'] = get_rse_usage(rse_id=rse['id'], source='expired')[0]['used'] diff --git a/lib/rucio/daemons/rsedecommissioner/config.py b/lib/rucio/daemons/rsedecommissioner/config.py index bbb2fd53ee..e04fb1156c 100644 --- a/lib/rucio/daemons/rsedecommissioner/config.py +++ b/lib/rucio/daemons/rsedecommissioner/config.py @@ -75,7 +75,7 @@ def set_status( :param rse_id: RSE ID. :param status: RSE decommissioning status. """ - config = attr_to_config(get_rse_attribute(rse_id, RseAttr.DECOMMISSION)) + config = attr_to_config(get_rse_attribute(rse_id, RseAttr.DECOMMISSION)) # type: ignore (get_rse_attribute could return None) config['status'] = status # add_rse_attribute can handle updating existing entries too add_rse_attribute(rse_id, RseAttr.DECOMMISSION, config_to_attr(config)) diff --git a/lib/rucio/daemons/rsedecommissioner/rse_decommissioner.py b/lib/rucio/daemons/rsedecommissioner/rse_decommissioner.py index f780f7d895..05d0a17407 100644 --- a/lib/rucio/daemons/rsedecommissioner/rse_decommissioner.py +++ b/lib/rucio/daemons/rsedecommissioner/rse_decommissioner.py @@ -98,7 +98,7 @@ def run_once( # Get the decommission attribute (encodes the decommissioning config) attr = get_rse_attribute(rse['id'], RseAttr.DECOMMISSION) try: - config = attr_to_config(attr) + config = attr_to_config(attr) # type: ignore (attr could be None) except InvalidStatusName: logger(logging.ERROR, 'RSE %s has an invalid decommissioning status', rse['rse']) diff --git a/lib/rucio/gateway/credential.py b/lib/rucio/gateway/credential.py index 2e90494003..a8eda09470 100644 --- a/lib/rucio/gateway/credential.py +++ b/lib/rucio/gateway/credential.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING from rucio.common import exception from rucio.core import credential @@ -23,6 +23,8 @@ if TYPE_CHECKING: from sqlalchemy.orm import Session + from rucio.common.constants import RSE_BASE_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL, SUPPORTED_SIGN_URL_SERVICES_LITERAL + @read_session def get_signed_url( @@ -30,8 +32,8 @@ def get_signed_url( appid: str, ip: str, rse: str, - service: Literal['gsc', 's3', 'swift'], - operation: Literal['read', 'write', 'delete'], + service: 'SUPPORTED_SIGN_URL_SERVICES_LITERAL', + operation: 'RSE_BASE_SUPPORTED_PROTOCOL_OPERATIONS_LITERAL', url: str, lifetime: int, vo: str = 'def', diff --git a/lib/rucio/rse/rsemanager.py b/lib/rucio/rse/rsemanager.py index 6f17d576bb..7e31c58979 100644 --- a/lib/rucio/rse/rsemanager.py +++ b/lib/rucio/rse/rsemanager.py @@ -21,7 +21,7 @@ from rucio.common import constants, exception, types, utils from rucio.common.config import config_get_int -from rucio.common.constants import RSE_SUPPORTED_PROTOCOL_OPERATIONS +from rucio.common.constants import RSE_ALL_SUPPORTED_PROTOCOL_OPERATIONS from rucio.common.constraints import STRING_TYPES from rucio.common.logging import formatted_logger from rucio.common.utils import GLOBALLY_SUPPORTED_CHECKSUMS, make_valid_did @@ -133,7 +133,7 @@ def _get_possible_protocols(rse_settings: types.RSESettingsDict, operation, sche def get_protocols_ordered(rse_settings: types.RSESettingsDict, operation, scheme=None, domain='wan', impl=None): - if operation not in RSE_SUPPORTED_PROTOCOL_OPERATIONS: + if operation not in RSE_ALL_SUPPORTED_PROTOCOL_OPERATIONS: raise exception.RSEOperationNotSupported('Operation %s is not supported' % operation) if domain and domain not in utils.rse_supported_protocol_domains(): @@ -145,7 +145,7 @@ def get_protocols_ordered(rse_settings: types.RSESettingsDict, operation, scheme def select_protocol(rse_settings: types.RSESettingsDict, operation, scheme=None, domain='wan'): - if operation not in RSE_SUPPORTED_PROTOCOL_OPERATIONS: + if operation not in RSE_ALL_SUPPORTED_PROTOCOL_OPERATIONS: raise exception.RSEOperationNotSupported('Operation %s is not supported' % operation) if domain and domain not in utils.rse_supported_protocol_domains(): @@ -173,7 +173,7 @@ def create_protocol(rse_settings: types.RSESettingsDict, operation, scheme=None, # Verify feasibility of Protocol operation = operation.lower() - if operation not in RSE_SUPPORTED_PROTOCOL_OPERATIONS: + if operation not in RSE_ALL_SUPPORTED_PROTOCOL_OPERATIONS: raise exception.RSEOperationNotSupported('Operation %s is not supported' % operation) if domain and domain not in utils.rse_supported_protocol_domains(): diff --git a/lib/rucio/web/rest/flaskapi/v1/credentials.py b/lib/rucio/web/rest/flaskapi/v1/credentials.py index 9c9923d7e7..e22e45a84b 100644 --- a/lib/rucio/web/rest/flaskapi/v1/credentials.py +++ b/lib/rucio/web/rest/flaskapi/v1/credentials.py @@ -17,6 +17,7 @@ from flask import Flask, request from werkzeug.datastructures import Headers +from rucio.common.constants import RSE_BASE_SUPPORTED_PROTOCOL_OPERATIONS, SUPPORTED_SIGN_URL_SERVICES from rucio.common.exception import CannotAuthenticate from rucio.gateway.credential import get_signed_url from rucio.web.rest.flaskapi.authenticated_bp import AuthenticatedBlueprint @@ -177,10 +178,10 @@ def get(self): return generate_http_error_flask(400, ValueError.__name__, 'Parameter "url" not found', headers=headers) url = request.args.get('url') - if service not in ['gcs', 's3', 'swift']: + if service not in SUPPORTED_SIGN_URL_SERVICES: return generate_http_error_flask(400, ValueError.__name__, 'Parameter "svc" must be either empty(=gcs), gcs, s3 or swift', headers=headers) - if operation not in ['read', 'write', 'delete']: + if operation not in RSE_BASE_SUPPORTED_PROTOCOL_OPERATIONS: return generate_http_error_flask(400, ValueError.__name__, 'Parameter "op" must be either empty(=read), read, write, or delete.', headers=headers) result = get_signed_url(account, appid, ip, rse=rse, service=service, operation=operation, url=url, lifetime=lifetime, vo=vo) diff --git a/tests/test_rse.py b/tests/test_rse.py index 0b6679b3b2..d7116ea30d 100644 --- a/tests/test_rse.py +++ b/tests/test_rse.py @@ -17,7 +17,7 @@ from rucio.client.replicaclient import ReplicaClient from rucio.common import exception -from rucio.common.constants import RseAttr +from rucio.common.constants import RSE_BASE_SUPPORTED_PROTOCOL_OPERATIONS, RseAttr from rucio.common.exception import Duplicate, InputValidationError, InvalidObject, ResourceTemporaryUnavailable, RSEAttributeNotFound, RSENotFound, RSEOperationNotSupported, RSEProtocolNotSupported from rucio.common.schema import get_schema_value from rucio.common.utils import CHECKSUM_KEY, GLOBALLY_SUPPORTED_CHECKSUMS @@ -995,7 +995,7 @@ def test_get_protocols_defaults(self, vo, rucio_client): rucio_client.add_protocol(protocol_rse, p) rse_attr = mgr.get_rse_info(rse=protocol_rse, vo=vo) - for op in ['delete', 'read', 'write']: + for op in RSE_BASE_SUPPORTED_PROTOCOL_OPERATIONS: # resp = rucio_client.get_protocols(protocol_rse, operation=op, default=True, protocol_domain='lan') p = mgr.select_protocol(rse_attr, op, domain='lan') print(p['scheme']) @@ -1005,7 +1005,7 @@ def test_get_protocols_defaults(self, vo, rucio_client): rucio_client.delete_protocols(protocol_rse, p['scheme']) rucio_client.delete_rse(protocol_rse) raise Exception('Unexpected protocols returned for %s: %s' % (op, p)) - for op in ['delete', 'read', 'write']: + for op in RSE_BASE_SUPPORTED_PROTOCOL_OPERATIONS: # resp = rucio_client.get_protocols(protocol_rse, operation=op, default=True, protocol_domain='wan') p = mgr.select_protocol(rse_attr, op, domain='wan') if ((op == 'delete') and (p['port'] != 17)) or ((op == 'read') and (p['port'] != 42)) or ((op == 'write') and (p['port'] != 19)):