Skip to content

Commit

Permalink
Start adding type annotations
Browse files Browse the repository at this point in the history
... for importgroups and related code.
  • Loading branch information
stsnel committed Apr 30, 2024
1 parent 5960e3a commit 5057c7c
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 35 deletions.
11 changes: 7 additions & 4 deletions yclienttools/common_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,27 @@
a config file, or if not defined a default value"""

import os
import sys
from typing import Optional

import yaml


def get_ca_file():
def get_ca_file() -> str:
return _get_parameter_with_default(
"ca_file", "/etc/irods/localhost_and_chain.crt")


def get_default_yoda_version():
def get_default_yoda_version() -> str:
return str(_get_parameter_with_default("default_yoda_version", "1.8"))


def _get_parameter_with_default(parameter, default_value):
def _get_parameter_with_default(parameter: str, default_value: str) -> str:
config_value = _get_parameter_from_config(parameter)
return config_value if config_value is not None else default_value


def _get_parameter_from_config(parameter):
def _get_parameter_from_config(parameter: str) -> Optional[str]:
config_filename = os.path.expanduser("~/.yodaclienttools.yml")
if not os.path.exists(config_filename):
return None
Expand All @@ -31,3 +33,4 @@ def _get_parameter_from_config(parameter):
except yaml.YAMLError as e:
print("Error occurred when opening configuration file.")
print(e)
sys.exit(1)
24 changes: 12 additions & 12 deletions yclienttools/common_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _string_list_to_list(self, s):
raise ValueError(
"Unable to convert string representation of list to list")

def call_uuGroupGetMembers(self, groupname):
def call_uuGroupGetMembers(self, groupname: str):
"""Returns list of group members"""
parms = OrderedDict([
('groupname', groupname)])
Expand All @@ -86,22 +86,22 @@ def call_uuGroupGetMembers(self, groupname):
"Group member list exceeds 1023 bytes")
return self._string_list_to_list(out)

def call_uuGroupUserRemove(self, groupname, user):
def call_uuGroupUserRemove(self, groupname: str, user: str):
"""Removes a user from a group"""
parms = OrderedDict([
('groupname', groupname),
('user', user)])
return self.call_rule('uuGroupUserRemove', parms, 2)

def call_uuGroupGetMemberType(self, groupname, user):
def call_uuGroupGetMemberType(self, groupname: str, user: str):
""":returns: member type of a group member"""
parms = OrderedDict([
('groupname', groupname),
('user', user)])
return self.call_rule('uuGroupGetMemberType', parms, 1)[0]

def call_uuGroupUserAddByOtherCreator(
self, groupname, username, creator_user, creator_zone):
self, groupname: str, username: str, creator_user: str, creator_zone: str):
"""Adds user to group on the behalf of a creator user.
:param: groupname
Expand All @@ -117,7 +117,7 @@ def call_uuGroupUserAddByOtherCreator(
('creatorZone', creator_zone)])
return self.call_rule('uuGroupUserAdd', parms, 2)

def call_uuGroupUserAdd(self, groupname, username):
def call_uuGroupUserAdd(self, groupname: str, username: str):
"""Adds user to group.
:param: groupname
Expand All @@ -129,7 +129,7 @@ def call_uuGroupUserAdd(self, groupname, username):
('username', username)])
return self.call_rule('uuGroupUserAdd', parms, 2)

def call_uuGroupUserChangeRole(self, groupname, username, newrole):
def call_uuGroupUserChangeRole(self, groupname: str, username: str, newrole: str):
"""Change role of user in group
:param groupname: name of group
Expand All @@ -143,7 +143,7 @@ def call_uuGroupUserChangeRole(self, groupname, username, newrole):
('newrole', newrole)])
return self.call_rule('uuGroupUserChangeRole', parms, 2)

def call_uuGroupExists(self, groupname):
def call_uuGroupExists(self, groupname: str):
"""Check whether group name exists on Yoda
:param groupname: name of group
Expand All @@ -153,7 +153,7 @@ def call_uuGroupExists(self, groupname):
[out] = self.call_rule('uuGroupExists', parms, 1)
return out == 'true'

def call_uuUserExists(self, username):
def call_uuUserExists(self, username: str):
"""Check whether user name exists on Yoda
:param username: name of user
Expand All @@ -163,8 +163,8 @@ def call_uuUserExists(self, username):
[out] = self.call_rule('uuUserExists', parms, 1)
return out == 'true'

def call_uuGroupAdd(self, groupname, category,
subcategory, description, classification, schema_id='default-2', expiration_date=''):
def call_uuGroupAdd(self, groupname: str, category: str,
subcategory: str, description: str, classification, schema_id: str = 'default-2', expiration_date: str = ''):
"""Adds a group
:param groupname: name of group
Expand Down Expand Up @@ -199,7 +199,7 @@ def call_uuGroupAdd(self, groupname, category,

return self.call_rule('uuGroupAdd', parms, 2)

def call_uuGroupModify(self, groupname, property, value):
def call_uuGroupModify(self, groupname: str, property: str, value: str):
"""Modifies one property of a group
:param groupname: name of group
Expand All @@ -213,7 +213,7 @@ def call_uuGroupModify(self, groupname, property, value):
('value', value)])
return self.call_rule('uuGroupModify', parms, 2)

def call_uuGroupRemove(self, groupname):
def call_uuGroupRemove(self, groupname: str):
"""Removes an empty group
:param groupname: name of group
Expand Down
36 changes: 18 additions & 18 deletions yclienttools/importgroups.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# Based on yoda-batch-add script by Ton Smeele


def parse_csv_file(input_file, args, yoda_version):
def parse_csv_file(input_file: str, args: argparse.Namespace, yoda_version: str) -> list:
extracted_data = []

with open(input_file, mode="r", encoding="utf-8-sig") as csv_file:
Expand Down Expand Up @@ -62,7 +62,7 @@ def parse_csv_file(input_file, args, yoda_version):
# keys are the column names, items are the list of items
for line in reader:
row_number += 1
d = {}
d: dict = {}
for j in range(len(line)):
item = line[j].strip()
if len(item):
Expand All @@ -82,30 +82,30 @@ def parse_csv_file(input_file, args, yoda_version):
return extracted_data


def _get_csv_possible_labels(yoda_version):
def _get_csv_possible_labels(yoda_version: str) -> list[str]:
if yoda_version in ('1.7', '1.8'):
return ['category', 'subcategory', 'groupname', 'viewer', 'member', 'manager']
else:
return ['category', 'subcategory', 'groupname', 'viewer', 'member', 'manager', 'expiration_date', 'schema_id']


def _get_csv_required_labels():
def _get_csv_required_labels() -> list[str]:
return ['category', 'subcategory', 'groupname']


def _get_csv_1_9_exclusive_labels():
def _get_csv_1_9_exclusive_labels() -> list[str]:
"""Returns labels that can only appear with yoda version 1.9 and higher."""
return ['expiration_date', 'schema_id']


def _get_csv_predefined_labels(yoda_version):
def _get_csv_predefined_labels(yoda_version: str) -> list[str]:
if yoda_version in ('1.7', '1.8'):
return ['category', 'subcategory', 'groupname']
else:
return ['category', 'subcategory', 'groupname', 'expiration_date', 'schema_id']


def _get_duplicate_columns(fields_list, yoda_version):
def _get_duplicate_columns(fields_list: list[str], yoda_version: str) -> set[str]:
""" Only checks columns that cannot have duplicates """
fields_seen = set()
duplicate_fields = set()
Expand All @@ -120,12 +120,12 @@ def _get_duplicate_columns(fields_list, yoda_version):
return duplicate_fields


def _get_duplicate_groups(row_data):
def _get_duplicate_groups(row_data: list) -> list[str]:
group_names = list(map(lambda r: r[2], row_data))
return list(unique_everseen(duplicates(group_names)))


def _process_csv_line(line, args, yoda_version):
def _process_csv_line(line: dict, args: argparse.Namespace, yoda_version: str) -> tuple:
if ('category' not in line or not len(line['category'])
or 'subcategory' not in line or not len(line['subcategory'])
or 'groupname' not in line or not len(line['groupname'])):
Expand Down Expand Up @@ -185,7 +185,7 @@ def _process_csv_line(line, args, yoda_version):
return row_data, None


def _are_roles_equivalent(a, b):
def _are_roles_equivalent(a: str, b: str) -> bool:
"""Checks whether two roles are equivalent. Needed because Yoda and Yoda-clienttools
use slightly different names for the roles."""
r_role_names = ["viewer", "reader"]
Expand All @@ -201,7 +201,7 @@ def _are_roles_equivalent(a, b):
return False


def validate_data(rule_interface, args, data):
def validate_data(rule_interface: RuleInterface, args: argparse.Namespace, data: list) -> list[str]:
errors = []
for (category, subcategory, groupname, managers, members, viewers, schema_id, expiration_date) in data:
if rule_interface.call_uuGroupExists(groupname) and not args.allow_update:
Expand All @@ -219,7 +219,7 @@ def validate_data(rule_interface, args, data):
return errors


def apply_data(rule_interface, args, data):
def apply_data(rule_interface: RuleInterface, args: argparse.Namespace, data: list) -> None:
for (category, subcategory, groupname, managers, members, viewers, schema_id, expiration_date) in data:
new_group = False

Expand Down Expand Up @@ -361,7 +361,7 @@ def apply_data(rule_interface, args, data):
print("Status: {} , Message: {}".format(status, msg))


def print_parsed_data(data):
def print_parsed_data(data: list) -> None:
print('Parsed data:')
print()

Expand All @@ -381,7 +381,7 @@ def print_parsed_data(data):
print()


def entry():
def entry() -> None:
'''Entry point'''
args = _get_args()
yoda_version = args.yoda_version if args.yoda_version is not None else common_config.get_default_yoda_version()
Expand Down Expand Up @@ -429,7 +429,7 @@ def entry():
session.cleanup()


def _get_args():
def _get_args() -> argparse.Namespace:
'''Parse command line arguments'''

parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -460,7 +460,7 @@ def _get_args():
return parser.parse_args()


def _get_format_help_text():
def _get_format_help_text() -> str:
return '''
The CSV file is expected to include the following labels in its header (the first row):
'category' = category for the group
Expand Down Expand Up @@ -493,12 +493,12 @@ def _get_format_help_text():
'''


def _exit_with_error(message):
def _exit_with_error(message: str) -> None:
print("Error: {}".format(message), file=sys.stderr)
sys.exit(1)


def _exit_with_validation_errors(errors):
def _exit_with_validation_errors(errors: list[str]) -> None:
for error in errors:
print("Validation error: {}".format(error), file=sys.stderr)
sys.exit(1)
2 changes: 1 addition & 1 deletion yclienttools/yoda_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from backports.functools_lru_cache import lru_cache # type: ignore[no-redef]


def is_valid_username(username, no_validate_domains):
def is_valid_username(username: str, no_validate_domains: bool):
"""Is this name a valid username
Returns whether valid and error
"""
Expand Down

0 comments on commit 5057c7c

Please sign in to comment.