From 7a247a8fd9dfe92f51764378a72bda22c9669733 Mon Sep 17 00:00:00 2001 From: Yidi Sprei Date: Sun, 24 Dec 2023 20:11:44 -0500 Subject: [PATCH] adding more advanced error handling --- .../oqs/built_in_functions.py | 47 ++++++++++++++++--- python_oqs_implementation/setup.py | 2 +- 2 files changed, 42 insertions(+), 7 deletions(-) diff --git a/python_oqs_implementation/oqs/built_in_functions.py b/python_oqs_implementation/oqs/built_in_functions.py index 1124cb7..84a4537 100644 --- a/python_oqs_implementation/oqs/built_in_functions.py +++ b/python_oqs_implementation/oqs/built_in_functions.py @@ -43,15 +43,39 @@ def bif_subtract(interpreter: 'OQSInterpreter', node: FunctionNode) -> int | flo raise OQSTypeError(message=f"Cannot subtract '{get_oqs_type(a)}' by '{get_oqs_type(b)}'") -def bif_multiply(interpreter: 'OQSInterpreter', node: FunctionNode) -> int | float: +def bif_multiply(interpreter: 'OQSInterpreter', node: FunctionNode) -> int | float | list | str: if len(node.args) < 2: raise OQSInvalidArgumentQuantityError( function_name=node.name, expected_min=2, expected_max=MAX_ARGS, actual=len(node.args) ) - result: int = 1 - for arg in node.args: - result *= interpreter.evaluate(arg) - return result + evaluated_args: list[any] = [interpreter.evaluate(arg) for arg in node.args] + completion: any = evaluated_args.pop(0) + for evaluated_arg in evaluated_args: + if isinstance(completion, (int, float)) and isinstance(evaluated_arg, (int, float)): + completion *= evaluated_arg + elif ( + isinstance(completion, list) and isinstance(evaluated_arg, int) + ) or (isinstance(completion, int) and isinstance(evaluated_arg, list)): + lst, multiplier = (completion, evaluated_arg) if isinstance( + completion, list + ) else (evaluated_arg, completion) + if multiplier > 1: + for i in range(multiplier - 1): + lst.extend(lst) + return lst + elif ( + isinstance(completion, str) and isinstance(evaluated_arg, int) + ) or (isinstance(completion, int) and isinstance(evaluated_arg, str)): + string, multiplier = (completion, evaluated_arg) if isinstance( + completion, str + ) else (evaluated_arg, completion) + return string * multiplier + + else: + raise OQSTypeError( + message=f"Cannot multiply '{get_oqs_type(completion)}' and '{get_oqs_type(evaluated_arg)}'" + ) + return completion def bif_divide(interpreter: 'OQSInterpreter', node: FunctionNode) -> int | float: @@ -77,7 +101,10 @@ def bif_exponentiate(interpreter: 'OQSInterpreter', node: FunctionNode) -> int | function_name=node.name, expected_min=2, expected_max=2, actual=len(node.args) ) base, exponent = [interpreter.evaluate(arg) for arg in node.args] - return pow(base, exponent) + if isinstance(base, (int, float)) and isinstance(exponent, (int, float)): + return pow(base, exponent) + else: + raise OQSTypeError(message=f"Cannot exponentiate type '{get_oqs_type(a)}' by type '{get_oqs_type(b)}'.") def bif_modulo(interpreter: 'OQSInterpreter', node: FunctionNode) -> int: @@ -86,6 +113,8 @@ def bif_modulo(interpreter: 'OQSInterpreter', node: FunctionNode) -> int: function_name=node.name, expected_min=2, expected_max=2, actual=len(node.args) ) a, b = [interpreter.evaluate(arg) for arg in node.args] + if not isinstance(a, int) or not isinstance(b, int): + raise OQSTypeError(message=f"Cannot perform modulo on types '{get_oqs_type(a)}' and '{get_oqs_type(b)}'") return a % b @@ -95,6 +124,8 @@ def bif_integer(interpreter: 'OQSInterpreter', node: FunctionNode) -> int: function_name=node.name, expected_min=1, expected_max=1, actual=len(node.args) ) value: any = interpreter.evaluate(node.args[0]) + if not isinstance(value, (int, float, str)): + raise OQSTypeError(message=f"Cannot convert type '{get_oqs_type(value)}' to integer") return int(value) @@ -104,6 +135,8 @@ def bif_decimal(interpreter: 'OQSInterpreter', node: FunctionNode) -> float: function_name=node.name, expected_min=1, expected_max=1, actual=len(node.args) ) value: any = interpreter.evaluate(node.args[0]) + if not isinstance(value, (int, float, str)): + raise OQSTypeError(message=f"Cannot convert type '{get_oqs_type(value)}' to float") return float(value) @@ -194,6 +227,8 @@ def bif_max(interpreter: 'OQSInterpreter', node: FunctionNode) -> int | float: function_name=node.name, expected_min=1, expected_max=MAX_ARGS, actual=len(node.args) ) numbers: list[any] = [interpreter.evaluate(arg) for arg in node.args] + if not all(isinstance(item, (int, float)) for item in numbers): + raise OQSTypeError(message="All arguments must be numbers for 'max'") return max(numbers) diff --git a/python_oqs_implementation/setup.py b/python_oqs_implementation/setup.py index 48f7547..92d4d04 100644 --- a/python_oqs_implementation/setup.py +++ b/python_oqs_implementation/setup.py @@ -3,7 +3,7 @@ setup( name='oqs', - version='0.4.1', + version='0.4.2', packages=find_packages(include=['oqs', 'oqs.*']), description= "OQS (Open Quick Script) is a Python library for interpreting versatile expressions, supporting basic to advanced "