Skip to content

Commit

Permalink
adding the ability to add functions
Browse files Browse the repository at this point in the history
  • Loading branch information
infuzu-yidisprei committed Dec 24, 2023
1 parent 4802c25 commit 97c18ad
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 5 deletions.
26 changes: 21 additions & 5 deletions python_oqs_implementation/oqs/engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import re
import time
from typing import Callable
from .interpreter import OQSInterpreter
from .errors import OQSBaseError
from .utils.shortcuts import get_oqs_type
Expand All @@ -13,25 +14,36 @@ def __init__(self, expression: str, variables: dict[str, any] | None = None, str


def evaluate_expression(
expression: str | ExpressionInput, variables: dict[str, any] | None = None, string_embedded: bool = False
expression: str | ExpressionInput,
variables: dict[str, any] | None = None,
string_embedded: bool = False,
additional_functions: list[tuple[str, Callable]] | None = None
) -> dict[str, any]:
if isinstance(expression, ExpressionInput):
variables: dict[str, any] | None = expression.variables
string_embedded: bool = expression.string_embedded
expression: str = expression.expression
if additional_functions is None:
additional_functions: list[tuple[str, Callable]] = []
try:
if string_embedded:
def replace_embedded(match: re.match):
embedded_expr: str = match.group(1)
embedded_result: dict[str, any] = evaluate_expression(
expression=embedded_expr, variables=variables, string_embedded=False
expression=embedded_expr,
variables=variables,
string_embedded=False,
additional_functions=additional_functions
)
return str(embedded_result["results"]["value"])

result_expression: str = re.sub(r'<\{(.*?)\}>', replace_embedded, expression)
return {"results": {"value": result_expression, "type": "String"}}

result: any = OQSInterpreter(expression=expression, variables=variables).results()
interpreter: OQSInterpreter = OQSInterpreter(expression=expression, variables=variables)
for function_name, function in additional_functions:
interpreter.add_additional_function(function_name=function_name, function=function)
result: any = interpreter.results()

return {"results": {"value": result, "type": get_oqs_type(result)}}
except OQSBaseError as e:
Expand All @@ -51,7 +63,8 @@ def oqs_engine(
string_embedded: bool = False,
report_usage: bool = False,
evaluate_multiple: bool = False,
expression_inputs: list[ExpressionInput] = None
expression_inputs: list[ExpressionInput] = None,
additional_functions: list[tuple[str, Callable]] | None = None
) -> dict[str, any]:
start_cpu_time: int = time.process_time_ns()
if evaluate_multiple:
Expand All @@ -63,7 +76,10 @@ def oqs_engine(
results: dict[str, any] = {"results": expression_results}
else:
results: dict[str, any] = evaluate_expression(
expression=expression, variables=variables, string_embedded=string_embedded
expression=expression,
variables=variables,
string_embedded=string_embedded,
additional_functions=additional_functions
)
if report_usage:
results["cpu_time_ns"] = time.process_time_ns() - start_cpu_time
Expand Down
3 changes: 3 additions & 0 deletions python_oqs_implementation/oqs/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ def __init__(self, expression: str, variables: dict[str, any] | None = None) ->
self.original_ast: ASTNode = self.parser.parse(expression=self.original_expression)
self.variables: dict[str, any] = variables if variables else {}

def add_additional_function(self, function_name: str, function: Callable):
self.FUNCTIONS[function_name] = function

def results(self) -> any:
return self.evaluate(self.original_ast)

Expand Down

0 comments on commit 97c18ad

Please sign in to comment.