From 430e7c13367689693c7a962a0a1b765e5df56290 Mon Sep 17 00:00:00 2001 From: yidis Date: Tue, 26 Dec 2023 17:08:54 -0500 Subject: [PATCH] cleaning up functions --- .../oqs/built_in_functions.py | 290 ++++-------------- python_oqs_implementation/oqs/utils/checks.py | 11 + 2 files changed, 71 insertions(+), 230 deletions(-) create mode 100644 python_oqs_implementation/oqs/utils/checks.py diff --git a/python_oqs_implementation/oqs/built_in_functions.py b/python_oqs_implementation/oqs/built_in_functions.py index eec871e..cdd66b5 100644 --- a/python_oqs_implementation/oqs/built_in_functions.py +++ b/python_oqs_implementation/oqs/built_in_functions.py @@ -13,6 +13,7 @@ get_error_name_mapping ) from .nodes import (FunctionNode, ASTNode) +from .utils.checks import ensure_function_arg_quantity from .utils.conversion import OQSJSONEncoder from .utils.shortcuts import (get_oqs_type, is_oqs_instance) @@ -20,10 +21,7 @@ def bif_add( interpreter: 'OQSInterpreter', node: FunctionNode ) -> int | float | list | str | dict | datetime.datetime | datetime.date | datetime.time | datetime.timedelta: - if len(node.args) < 2: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=2, expected_max=MAX_ARGS, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=2) evaluated_args: list[any] = [interpreter.evaluate(arg) for arg in node.args] completion: any = evaluated_args.pop(0) for evaluated_arg in evaluated_args: @@ -57,7 +55,7 @@ def bif_add( time_seconds: int = (completion.hour * 3600) + (completion.minute * 60) + completion.second total_seconds: int = time_seconds + evaluated_arg.seconds new_time: datetime.time = (datetime.datetime.min + datetime.timedelta(seconds=total_seconds)).time() - completion: daetime.time = new_time + completion: datetime.time = new_time else: raise OQSTypeError(message=f"Cannot add '{get_oqs_type(completion)}' and '{get_oqs_type(evaluated_arg)}'") return completion @@ -66,10 +64,7 @@ def bif_add( def bif_subtract( interpreter: 'OQSInterpreter', node: FunctionNode ) -> int | float | list | str | datetime.datetime | datetime.date | datetime.time | datetime.timedelta: - if len(node.args) != 2: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=2, expected_max=2, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=2, max_args=2) a, b = [interpreter.evaluate(arg) for arg in node.args] if isinstance(a, (int, float)) and isinstance(b, (int, float)): return a - b @@ -99,10 +94,7 @@ def bif_subtract( def bif_multiply(interpreter: 'OQSInterpreter', node: FunctionNode) -> int | float | list | str: - if len(node.args) < 2: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=2, expected_max=MAX_ARGS, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=2) evaluated_args: list[any] = [interpreter.evaluate(arg) for arg in node.args] completion: any = evaluated_args.pop(0) for evaluated_arg in evaluated_args: @@ -134,10 +126,7 @@ def bif_multiply(interpreter: 'OQSInterpreter', node: FunctionNode) -> int | flo def bif_divide(interpreter: 'OQSInterpreter', node: FunctionNode) -> int | float: - if len(node.args) != 2: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=2, expected_max=2, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=2, max_args=2) a, b = [interpreter.evaluate(arg) for arg in node.args] if b == 0: raise OQSDivisionByZeroError() @@ -151,10 +140,7 @@ def bif_divide(interpreter: 'OQSInterpreter', node: FunctionNode) -> int | float def bif_exponentiate(interpreter: 'OQSInterpreter', node: FunctionNode) -> int | float | complex: - if len(node.args) != 2: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=2, expected_max=2, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=2, max_args=2) base, exponent = [interpreter.evaluate(arg) for arg in node.args] if isinstance(base, (int, float)) and isinstance(exponent, (int, float)): return pow(base, exponent) @@ -165,10 +151,7 @@ def bif_exponentiate(interpreter: 'OQSInterpreter', node: FunctionNode) -> int | def bif_modulo(interpreter: 'OQSInterpreter', node: FunctionNode) -> int: - if len(node.args) != 2: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=2, expected_max=2, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=2, max_args=2) a, b = [interpreter.evaluate(arg) for arg in node.args] if not isinstance(a, int) or not isinstance(b, int): raise OQSTypeError(message=f"Cannot perform modulo on types '{get_oqs_type(a)}' and '{get_oqs_type(b)}'") @@ -176,10 +159,7 @@ def bif_modulo(interpreter: 'OQSInterpreter', node: FunctionNode) -> int: def bif_less_than(interpreter: 'OQSInterpreter', node: FunctionNode) -> int: - if len(node.args) < 2: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=2, expected_max=MAX_ARGS, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=2) evaluated_args: dict[int, any] = {} for i, arg in enumerate(node.args): if len(node.args) > i + 1: @@ -215,10 +195,7 @@ def bif_less_than(interpreter: 'OQSInterpreter', node: FunctionNode) -> int: def bif_greater_than(interpreter: 'OQSInterpreter', node: FunctionNode) -> int: - if len(node.args) < 2: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=2, expected_max=MAX_ARGS, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=2) evaluated_args: dict[int, any] = {} for i, arg in enumerate(node.args): if len(node.args) > i + 1: @@ -254,10 +231,7 @@ def bif_greater_than(interpreter: 'OQSInterpreter', node: FunctionNode) -> int: def bif_less_than_or_equal(interpreter: 'OQSInterpreter', node: FunctionNode) -> int: - if len(node.args) < 2: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=2, expected_max=MAX_ARGS, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=2) evaluated_args: dict[int, any] = {} for i, arg in enumerate(node.args): if len(node.args) > i + 1: @@ -293,10 +267,7 @@ def bif_less_than_or_equal(interpreter: 'OQSInterpreter', node: FunctionNode) -> def bif_greater_than_or_equal(interpreter: 'OQSInterpreter', node: FunctionNode) -> int: - if len(node.args) < 2: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=2, expected_max=MAX_ARGS, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=2) evaluated_args: dict[int, any] = {} for i, arg in enumerate(node.args): if len(node.args) > i + 1: @@ -332,10 +303,7 @@ def bif_greater_than_or_equal(interpreter: 'OQSInterpreter', node: FunctionNode) def bif_equals(interpreter: 'OQSInterpreter', node: FunctionNode) -> int: - if len(node.args) < 2: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=2, expected_max=MAX_ARGS, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=2) evaluated_args: dict[int, any] = {} for i, arg in enumerate(node.args): if len(node.args) > i + 1: @@ -355,10 +323,7 @@ def bif_equals(interpreter: 'OQSInterpreter', node: FunctionNode) -> int: def bif_not_equals(interpreter: 'OQSInterpreter', node: FunctionNode) -> int: - if len(node.args) < 2: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=2, expected_max=MAX_ARGS, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=2) evaluated_args: dict[int, any] = {} for i, arg in enumerate(node.args): if len(node.args) > i + 1: @@ -378,10 +343,7 @@ def bif_not_equals(interpreter: 'OQSInterpreter', node: FunctionNode) -> int: def bif_strictly_equals(interpreter: 'OQSInterpreter', node: FunctionNode) -> int: - if len(node.args) < 2: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=2, expected_max=MAX_ARGS, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=2) evaluated_args: dict[int, any] = {} for i, arg in enumerate(node.args): if len(node.args) > i + 1: @@ -401,10 +363,7 @@ def bif_strictly_equals(interpreter: 'OQSInterpreter', node: FunctionNode) -> in def bif_strictly_not_equals(interpreter: 'OQSInterpreter', node: FunctionNode) -> int: - if len(node.args) < 2: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=2, expected_max=MAX_ARGS, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=2) evaluated_args: dict[int, any] = {} for i, arg in enumerate(node.args): if len(node.args) > i + 1: @@ -424,10 +383,7 @@ def bif_strictly_not_equals(interpreter: 'OQSInterpreter', node: FunctionNode) - def bif_and(interpreter: 'OQSInterpreter', node: FunctionNode) -> bool: - if len(node.args) < 1: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=1, expected_max=MAX_ARGS, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=2) for arg in node.args: if not interpreter.evaluate(arg): return False @@ -435,10 +391,7 @@ def bif_and(interpreter: 'OQSInterpreter', node: FunctionNode) -> bool: def bif_or(interpreter: 'OQSInterpreter', node: FunctionNode) -> bool: - if len(node.args) < 1: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=1, expected_max=MAX_ARGS, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=2) for arg in node.args: if interpreter.evaluate(arg): return True @@ -446,19 +399,13 @@ def bif_or(interpreter: 'OQSInterpreter', node: FunctionNode) -> bool: def bif_not(interpreter: 'OQSInterpreter', node: FunctionNode) -> bool: - if len(node.args) != 1: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=1, expected_max=1, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=1, max_args=1) value: any = interpreter.evaluate(node.args[0]) return not bool(value) def bif_integer(interpreter: 'OQSInterpreter', node: FunctionNode) -> int: - if len(node.args) != 1: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=1, expected_max=1, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=1, max_args=1) value: any = interpreter.evaluate(node.args[0]) if not isinstance(value, (int, float, str)): raise OQSTypeError(message=f"Cannot convert type '{get_oqs_type(value)}' to integer") @@ -466,10 +413,7 @@ def bif_integer(interpreter: 'OQSInterpreter', node: FunctionNode) -> int: def bif_decimal(interpreter: 'OQSInterpreter', node: FunctionNode) -> float: - if len(node.args) != 1: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=1, expected_max=1, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=1, max_args=1) value: any = interpreter.evaluate(node.args[0]) if not isinstance(value, (int, float, str)): raise OQSTypeError(message=f"Cannot convert type '{get_oqs_type(value)}' to float") @@ -477,10 +421,7 @@ def bif_decimal(interpreter: 'OQSInterpreter', node: FunctionNode) -> float: def bif_string(interpreter: 'OQSInterpreter', node: FunctionNode) -> str: - if len(node.args) != 1: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=1, expected_max=1, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=1, max_args=1) value: any = interpreter.evaluate(node.args[0]) return json.dumps(value, cls=OQSJSONEncoder) @@ -505,19 +446,13 @@ def bif_kvs(interpreter: 'OQSInterpreter', node: FunctionNode) -> dict[str, any] def bif_boolean(interpreter: 'OQSInterpreter', node: FunctionNode) -> bool: - if len(node.args) != 1: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=1, expected_max=1, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=1, max_args=1) value: any = interpreter.evaluate(node.args[0]) return bool(value) def bif_keys(interpreter: 'OQSInterpreter', node: FunctionNode) -> list[str]: - if len(node.args) != 1: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=1, expected_max=1, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=1, max_args=1) kvs: any = interpreter.evaluate(node.args[0]) if not isinstance(kvs, dict): raise OQSTypeError(message='Argument must be a KVS') @@ -525,10 +460,7 @@ def bif_keys(interpreter: 'OQSInterpreter', node: FunctionNode) -> list[str]: def bif_values(interpreter: 'OQSInterpreter', node: FunctionNode) -> list[any]: - if len(node.args) != 1: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=1, expected_max=1, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=1, max_args=1) kvs: any = interpreter.evaluate(node.args[0]) if not isinstance(kvs, dict): raise OQSTypeError(message='Argument must be a KVS') @@ -536,10 +468,7 @@ def bif_values(interpreter: 'OQSInterpreter', node: FunctionNode) -> list[any]: def bif_unique(interpreter: 'OQSInterpreter', node: FunctionNode) -> list[any]: - if len(node.args) != 1: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=1, expected_max=1, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=1, max_args=1) lst: any = interpreter.evaluate(node.args[0]) if not isinstance(lst, list): raise OQSTypeError(message='Argument must be a list') @@ -547,10 +476,7 @@ def bif_unique(interpreter: 'OQSInterpreter', node: FunctionNode) -> list[any]: def bif_reverse(interpreter: 'OQSInterpreter', node: FunctionNode) -> list[any]: - if len(node.args) != 1: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=1, expected_max=1, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=1, max_args=1) lst: any = interpreter.evaluate(node.args[0]) if not isinstance(lst, list): raise OQSTypeError(message='Argument must be a list') @@ -558,10 +484,7 @@ def bif_reverse(interpreter: 'OQSInterpreter', node: FunctionNode) -> list[any]: def bif_max(interpreter: 'OQSInterpreter', node: FunctionNode) -> int | float: - if len(node.args) < 1: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=1, expected_max=MAX_ARGS, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=1) numbers: list[any] = [interpreter.evaluate(arg) for arg in node.args] if not all(isinstance(item, (int, float)) for item in numbers): raise OQSTypeError(message="All arguments must be numbers for 'max'") @@ -569,19 +492,13 @@ def bif_max(interpreter: 'OQSInterpreter', node: FunctionNode) -> int | float: def bif_min(interpreter: 'OQSInterpreter', node: FunctionNode) -> int | float: - if len(node.args) < 1: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=1, expected_max=MAX_ARGS, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=1) numbers: list[any] = [interpreter.evaluate(arg) for arg in node.args] return min(numbers) def bif_sum(interpreter: 'OQSInterpreter', node: FunctionNode) -> int | float: - if len(node.args) != 1: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=1, expected_max=1, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=1, max_args=1) lst: any = interpreter.evaluate(node.args[0]) if not isinstance(lst, list) or not all(isinstance(item, (int, float)) for item in lst): raise OQSTypeError(message='Argument must be a list of numbers') @@ -589,10 +506,7 @@ def bif_sum(interpreter: 'OQSInterpreter', node: FunctionNode) -> int | float: def bif_length(interpreter: 'OQSInterpreter', node: FunctionNode) -> int: - if len(node.args) != 1: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=1, expected_max=1, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=1, max_args=1) value: any = interpreter.evaluate(node.args[0]) if not isinstance(value, (str, list, dict)): raise OQSTypeError(message="Argument must be a string, list or KVS") @@ -600,10 +514,7 @@ def bif_length(interpreter: 'OQSInterpreter', node: FunctionNode) -> int: def bif_append(interpreter: 'OQSInterpreter', node: FunctionNode) -> list[any]: - if len(node.args) != 2: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=2, expected_max=2, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=2, max_args=2) lst, item = [interpreter.evaluate(arg) for arg in node.args] if not isinstance(lst, list): raise OQSTypeError(message="First argument must be a list") @@ -612,10 +523,7 @@ def bif_append(interpreter: 'OQSInterpreter', node: FunctionNode) -> list[any]: def bif_update(interpreter: 'OQSInterpreter', node: FunctionNode) -> list[any] | dict[str, any]: - if len(node.args) != 3: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=3, expected_max=3, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=3, max_args=3) container, key_or_index, value = [interpreter.evaluate(arg) for arg in node.args] if isinstance(container, list): if not isinstance(key_or_index, int): @@ -634,10 +542,7 @@ def bif_update(interpreter: 'OQSInterpreter', node: FunctionNode) -> list[any] | def bif_remove_item(interpreter: 'OQSInterpreter', node: FunctionNode) -> list[any] | dict[str, any]: - if len(node.args) < 2: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=2, expected_max=3, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=2) container, item = [interpreter.evaluate(arg) for arg in node.args[:2]] max_occurrences: int = interpreter.evaluate(node.args[2]) if len(node.args) == 3 else MAX_ARGS if not isinstance(max_occurrences, int): @@ -659,10 +564,7 @@ def bif_remove_item(interpreter: 'OQSInterpreter', node: FunctionNode) -> list[a def bif_remove(interpreter: 'OQSInterpreter', node: FunctionNode) -> list[any] | dict[str, any]: - if len(node.args) != 2: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=2, expected_max=2, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=2, max_args=2) container, key_or_index = [interpreter.evaluate(arg) for arg in node.args] if isinstance(container, list): if not isinstance(key_or_index, int): @@ -680,10 +582,7 @@ def bif_remove(interpreter: 'OQSInterpreter', node: FunctionNode) -> list[any] | def bif_access(interpreter: 'OQSInterpreter', node: FunctionNode): - if len(node.args) < 2: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=2, expected_max=3, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=2) container, key_or_index = [interpreter.evaluate(arg) for arg in node.args[:2]] default_value = interpreter.evaluate(node.args[2]) if len(node.args) == 3 else None if isinstance(container, list): @@ -699,10 +598,7 @@ def bif_access(interpreter: 'OQSInterpreter', node: FunctionNode): def bif_if(interpreter: 'OQSInterpreter', node: FunctionNode) -> any: - if len(node.args) < 2: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=2, expected_max=MAX_ARGS, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=2) for i in range(0, len(node.args) - 1, 2): condition: any = interpreter.evaluate(node.args[i]) if condition: @@ -713,19 +609,13 @@ def bif_if(interpreter: 'OQSInterpreter', node: FunctionNode) -> any: def bif_type(interpreter: 'OQSInterpreter', node: FunctionNode) -> str: - if len(node.args) != 1: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=1, expected_max=1, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=1, max_args=1) argument: any = interpreter.evaluate(node.args[0]) return get_oqs_type(argument) def bif_is_type(interpreter: 'OQSInterpreter', node: FunctionNode) -> bool: - if len(node.args) != 2: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=2, expected_max=2, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=2, max_args=2) value, expected_type = [interpreter.evaluate(arg) for arg in node.args] if not isinstance(expected_type, str): raise OQSTypeError(message=f"Second argument must be a String. Instead got '{get_oqs_type(expected_type)}'.") @@ -733,10 +623,7 @@ def bif_is_type(interpreter: 'OQSInterpreter', node: FunctionNode) -> bool: def bif_try(interpreter: 'OQSInterpreter', node: FunctionNode) -> any: - if len(node.args) < 3: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=3, expected_max=MAX_ARGS, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=3) if len(node.args) % 2 == 0: raise OQSFunctionEvaluationError( function_name=node.name, message=f"Expected an odd amount of input arguments. Instead got {len(node.args)}." @@ -759,10 +646,7 @@ def bif_try(interpreter: 'OQSInterpreter', node: FunctionNode) -> any: def bif_range(interpreter: 'OQSInterpreter', node: FunctionNode) -> list[int]: - if not (1 < len(node.args) < 3): - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=1, expected_max=3, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=1, max_args=3) start: int = 0 step: int = 1 stop: int = 1 @@ -782,10 +666,7 @@ def bif_range(interpreter: 'OQSInterpreter', node: FunctionNode) -> list[int]: def bif_for_or_map(interpreter: 'OQSInterpreter', node: FunctionNode) -> list[any]: - if len(node.args) != 3: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=3, expected_max=3, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=3, max_args=3) looping_list: any = interpreter.evaluate(node.args[0]) variable_name: any = interpreter.evaluate(node.args[1]) expression: ASTNode = node.args[2] @@ -805,10 +686,7 @@ def bif_for_or_map(interpreter: 'OQSInterpreter', node: FunctionNode) -> list[an def bif_raise(interpreter: 'OQSInterpreter', node: FunctionNode) -> any: - if len(node.args) != 2: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=2, expected_max=2, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=2, max_args=2) error_name, error_message = [interpreter.evaluate(arg) for arg in node.args] if not isinstance(error_name, str): raise OQSTypeError( @@ -831,10 +709,7 @@ def __init__(self): def bif_filter(interpreter: 'OQSInterpreter', node: FunctionNode) -> list[any] | dict[str, any]: - if len(node.args) != 3: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=3, expected_max=3, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=3, max_args=3) collection, unevaluated_variable_name, predicate = node.args collection_value: any = interpreter.evaluate(collection) if not isinstance(collection_value, (list, dict)): @@ -863,10 +738,7 @@ def bif_filter(interpreter: 'OQSInterpreter', node: FunctionNode) -> list[any] | def bif_sort(interpreter: 'OQSInterpreter', node: FunctionNode) -> list[any]: - if len(node.args) < 3 or len(node.args) > 4: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=3, expected_max=4, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=3, max_args=4) collection, unevaluated_variable_name, key_expression = node.args[:3] descending: bool = interpreter.evaluate(node.args[3]) if len(node.args) == 4 else False @@ -893,10 +765,7 @@ def evaluate_expression_with_variable(item: any) -> any: def bif_flatten(interpreter: 'OQSInterpreter', node: FunctionNode) -> list[any]: - if len(node.args) != 1: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=1, expected_max=1, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=1, max_args=1) list_to_flatten: any = interpreter.evaluate(node.args[0]) if not isinstance(list_to_flatten, list): raise OQSTypeError( @@ -916,10 +785,7 @@ def flatten(lst): def bif_slice(interpreter: 'OQSInterpreter', node: FunctionNode) -> list[any] | str: - if len(node.args) < 2 or len(node.args) > 3: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=2, expected_max=3, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=2, max_args=3) collection, start = [interpreter.evaluate(arg) for arg in node.args[:2]] end: any = interpreter.evaluate(node.args[2]) if len(node.args) == 3 else None @@ -938,10 +804,7 @@ def bif_slice(interpreter: 'OQSInterpreter', node: FunctionNode) -> list[any] | def bif_in(interpreter: 'OQSInterpreter', node: FunctionNode) -> bool: - if len(node.args) != 2: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=2, expected_max=2, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=2, max_args=2) value, collection = [interpreter.evaluate(arg) for arg in node.args] if isinstance(collection, list): return value in collection @@ -955,10 +818,7 @@ def bif_in(interpreter: 'OQSInterpreter', node: FunctionNode) -> bool: def bif_date(interpreter: 'OQSInterpreter', node: FunctionNode) -> datetime.date: - if len(node.args) != 3: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=3, expected_max=3, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=3, max_args=3) year, month, day = [interpreter.evaluate(arg) for arg in node.args] if not all(isinstance(i, int) for i in [year, month, day]): raise OQSTypeError( @@ -972,10 +832,7 @@ def bif_date(interpreter: 'OQSInterpreter', node: FunctionNode) -> datetime.date def bif_time(interpreter: 'OQSInterpreter', node: FunctionNode) -> datetime.time: - if not (3 <= len(node.args) <= 4): - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=3, expected_max=4, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=3, max_args=4) hour, minute, second, *ms = [interpreter.evaluate(arg) for arg in node.args] if not all(isinstance(i, int) for i in [hour, minute, second] + ms): raise OQSTypeError( @@ -990,10 +847,7 @@ def bif_time(interpreter: 'OQSInterpreter', node: FunctionNode) -> datetime.time def bif_datetime(interpreter: 'OQSInterpreter', node: FunctionNode) -> datetime.datetime: - if not (6 <= len(node.args) <= 7): - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=6, expected_max=7, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=6, max_args=7) year, month, day, hour, minute, second, *ms = [interpreter.evaluate(arg) for arg in node.args] if not all(isinstance(i, int) for i in [year, month, day, hour, minute, second] + ms): raise OQSTypeError( @@ -1008,10 +862,7 @@ def bif_datetime(interpreter: 'OQSInterpreter', node: FunctionNode) -> datetime. def bif_duration(interpreter: 'OQSInterpreter', node: FunctionNode) -> datetime.timedelta: - if not (4 <= len(node.args) <= 5): - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=4, expected_max=5, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=4, max_args=5) days, hours, minutes, seconds, *ms = [interpreter.evaluate(arg) for arg in node.args] if not all(isinstance(i, int) for i in [days, hours, minutes, seconds] + ms): raise OQSTypeError( @@ -1023,36 +874,24 @@ def bif_duration(interpreter: 'OQSInterpreter', node: FunctionNode) -> datetime. def bif_now(interpreter: 'OQSInterpreter', node: FunctionNode) -> datetime.datetime: - if len(node.args) != 0: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=0, expected_max=0, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=0, max_args=0) return datetime.datetime.utcnow() def bif_today(interpreter: 'OQSInterpreter', node: FunctionNode) -> datetime.date: - if len(node.args) != 0: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=0, expected_max=0, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=0, max_args=0) return datetime.date.today() def bif_time_now(interpreter: 'OQSInterpreter', node: FunctionNode) -> datetime.time: - if len(node.args) != 0: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=0, expected_max=0, actual=len(node.args) - ) - return datetime.utcnow().time() + ensure_function_arg_quantity(node=node, min_args=0, max_args=0) + return datetime.datetime.utcnow().time() def bif_parse_temporal( interpreter: 'OQSInterpreter', node: FunctionNode ) -> datetime.datetime | datetime.date | datetime.time | datetime.timedelta: - if not (2 <= len(node.args) <= 3): - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=2, expected_max=3, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=2, max_args=3) string, temporal_type, *optional_format = [interpreter.evaluate(arg) for arg in node.args] if not all(isinstance(arg, str) for arg in [string, temporal_type] + optional_format): raise OQSTypeError( @@ -1088,10 +927,7 @@ def bif_parse_temporal( def bif_format_temporal(interpreter: 'OQSInterpreter', node: FunctionNode) -> str: - if len(node.args) != 2: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=2, expected_max=2, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=2, max_args=2) temporal, format_str = [interpreter.evaluate(arg) for arg in node.args] if not isinstance(format_str, str): raise OQSTypeError( @@ -1109,10 +945,7 @@ def bif_format_temporal(interpreter: 'OQSInterpreter', node: FunctionNode) -> st def bif_extract_date(interpreter: 'OQSInterpreter', node: FunctionNode) -> datetime.date: - if len(node.args) != 1: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=1, expected_max=1, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=1, max_args=1) datetime_obj: datetime.datetime = interpreter.evaluate(node.args[0]) if not isinstance(datetime_obj, datetime.datetime): raise OQSTypeError(message=f"Argument must be a DateTime type. Instead got '{get_oqs_type(datetime_obj)}'.") @@ -1120,10 +953,7 @@ def bif_extract_date(interpreter: 'OQSInterpreter', node: FunctionNode) -> datet def bif_extract_time(interpreter: 'OQSInterpreter', node: FunctionNode) -> datetime.time: - if len(node.args) != 1: - raise OQSInvalidArgumentQuantityError( - function_name=node.name, expected_min=1, expected_max=1, actual=len(node.args) - ) + ensure_function_arg_quantity(node=node, min_args=1, max_args=1) datetime_obj: datetime.datetime = interpreter.evaluate(node.args[0]) if not isinstance(datetime_obj, datetime.datetime): raise OQSTypeError(message=f"Argument must be a DateTime type. Instead got '{get_oqs_type(datetime_obj)}'.") diff --git a/python_oqs_implementation/oqs/utils/checks.py b/python_oqs_implementation/oqs/utils/checks.py new file mode 100644 index 0000000..fb0c693 --- /dev/null +++ b/python_oqs_implementation/oqs/utils/checks.py @@ -0,0 +1,11 @@ +from ..errors import OQSInvalidArgumentQuantityError +from ..nodes import FunctionNode +from ..constants.values import MAX_ARGS + + +def ensure_function_arg_quantity(node: FunctionNode, min_args: int, max_args: int = MAX_ARGS) -> None: + arg_count: int = len(node.args) + if not (min_args <= arg_count <= max_args): + raise OQSInvalidArgumentQuantityError( + function_name=node.name, expected_min=min_args, expected_max=max_args, actual=arg_count + )