diff --git a/dace/frontend/fortran/ast_transforms.py b/dace/frontend/fortran/ast_transforms.py index 0c96560fba..57508d6d90 100644 --- a/dace/frontend/fortran/ast_transforms.py +++ b/dace/frontend/fortran/ast_transforms.py @@ -184,7 +184,7 @@ def __init__(self, funcs=None): from dace.frontend.fortran.intrinsics import FortranIntrinsics self.excepted_funcs = [ - "malloc", "exp", "pow", "sqrt", "cbrt", "max", "abs", "min", "__dace_sign", "tanh", + "malloc", "pow", "cbrt", "__dace_sign", "tanh", "atan2", "__dace_epsilon", *FortranIntrinsics.function_names() ] @@ -220,7 +220,7 @@ def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): from dace.frontend.fortran.intrinsics import FortranIntrinsics if not stop and node.name.name not in [ - "malloc", "exp", "pow", "sqrt", "cbrt", "max", "min", "abs", "tanh", "__dace_epsilon", *FortranIntrinsics.call_extraction_exemptions() + "malloc", "pow", "cbrt", "__dace_epsilon", *FortranIntrinsics.call_extraction_exemptions() ]: self.nodes.append(node) return self.generic_visit(node) @@ -241,7 +241,7 @@ def __init__(self, count=0): def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): from dace.frontend.fortran.intrinsics import FortranIntrinsics - if node.name.name in ["malloc", "exp", "pow", "sqrt", "cbrt", "max", "min", "abs", "tanh", "__dace_epsilon", *FortranIntrinsics.call_extraction_exemptions()]: + if node.name.name in ["malloc", "pow", "cbrt", "__dace_epsilon", *FortranIntrinsics.call_extraction_exemptions()]: return self.generic_visit(node) if hasattr(node, "subroutine"): if node.subroutine is True: @@ -251,6 +251,11 @@ def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): else: self.count = self.count + 1 tmp = self.count + + for i, arg in enumerate(node.args): + # Ensure we allow to extract function calls from arguments + node.args[i] = self.visit(arg) + return ast_internal_classes.Name_Node(name="tmp_call_" + str(tmp - 1)) def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node): @@ -263,9 +268,13 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No for i in res: if i == child: res.pop(res.index(i)) - temp = self.count if res is not None: - for i in range(0, len(res)): + # Variables are counted from 0...end, starting from main node, to all calls nested + # in main node arguments. + # However, we need to define nested ones first. + # We go in reverse order, counting from end-1 to 0. + temp = self.count + len(res) - 1 + for i in reversed(range(0, len(res))): newbody.append( ast_internal_classes.Decl_Stmt_Node(vardecl=[ @@ -282,7 +291,7 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No type=res[i].type), rval=res[i], line_number=child.line_number)) - temp = temp + 1 + temp = temp - 1 if isinstance(child, ast_internal_classes.Call_Expr_Node): new_args = [] if hasattr(child, "args"): @@ -368,7 +377,8 @@ def __init__(self): self.nodes: List[ast_internal_classes.Array_Subscript_Node] = [] def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): - if node.name.name in ["sqrt", "exp", "pow", "max", "min", "abs", "tanh"]: + from dace.frontend.fortran.intrinsics import FortranIntrinsics + if node.name.name in ["pow", "atan2", "tanh", *FortranIntrinsics.retained_function_names()]: return self.generic_visit(node) else: return @@ -401,7 +411,8 @@ def __init__(self, ast: ast_internal_classes.FNode, normalize_offsets: bool = Fa self.scope_vars.visit(ast) def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): - if node.name.name in ["sqrt", "exp", "pow", "max", "min", "abs", "tanh"]: + from dace.frontend.fortran.intrinsics import FortranIntrinsics + if node.name.name in ["pow", "atan2", "tanh", *FortranIntrinsics.retained_function_names()]: return self.generic_visit(node) else: return node diff --git a/dace/frontend/fortran/fortran_parser.py b/dace/frontend/fortran/fortran_parser.py index 52344c141f..1cdecc99a8 100644 --- a/dace/frontend/fortran/fortran_parser.py +++ b/dace/frontend/fortran/fortran_parser.py @@ -818,7 +818,8 @@ def binop2sdfg(self, node: ast_internal_classes.BinOp_Node, sdfg: SDFG, cfg: Con calls.visit(node) if len(calls.nodes) == 1: augmented_call = calls.nodes[0] - if augmented_call.name.name not in ["sqrt", "exp", "pow", "max", "min", "abs", "tanh", "__dace_epsilon"]: + from dace.frontend.fortran.intrinsics import FortranIntrinsics + if augmented_call.name.name not in ["pow", "atan2", "tanh", "__dace_epsilon", *FortranIntrinsics.retained_function_names()]: augmented_call.args.append(node.lval) augmented_call.hasret = True self.call2sdfg(augmented_call, sdfg, cfg) @@ -1090,7 +1091,8 @@ def create_ast_from_string( program = ast_transforms.ArrayToLoop(program).visit(program) for transformation in own_ast.fortran_intrinsics().transformations(): - program = transformation(program).visit(program) + transformation.initialize(program) + program = transformation.visit(program) program = ast_transforms.ForDeclarer().visit(program) program = ast_transforms.IndexExtractor(program, normalize_offsets).visit(program) @@ -1126,7 +1128,8 @@ def create_sdfg_from_string( program = ast_transforms.ArrayToLoop(program).visit(program) for transformation in own_ast.fortran_intrinsics().transformations(): - program = transformation(program).visit(program) + transformation.initialize(program) + program = transformation.visit(program) program = ast_transforms.ForDeclarer().visit(program) program = ast_transforms.IndexExtractor(program, normalize_offsets).visit(program) @@ -1172,7 +1175,8 @@ def create_sdfg_from_fortran_file(source_string: str, use_experimental_cfg_block program = ast_transforms.ArrayToLoop(program).visit(program) for transformation in own_ast.fortran_intrinsics(): - program = transformation(program).visit(program) + transformation.initialize(program) + program = transformation.visit(program) program = ast_transforms.ForDeclarer().visit(program) program = ast_transforms.IndexExtractor(program).visit(program) diff --git a/dace/frontend/fortran/intrinsics.py b/dace/frontend/fortran/intrinsics.py index c2e5afe79b..af44a8dfb5 100644 --- a/dace/frontend/fortran/intrinsics.py +++ b/dace/frontend/fortran/intrinsics.py @@ -2,6 +2,7 @@ from abc import abstractmethod import copy import math +from collections import namedtuple from typing import Any, List, Optional, Set, Tuple, Type from dace.frontend.fortran import ast_internal_classes @@ -26,34 +27,175 @@ def replace(func_name: ast_internal_classes.Name_Node, args: ast_internal_classe def has_transformation() -> bool: return False -class SelectedKind(IntrinsicTransformation): +class IntrinsicNodeTransformer(NodeTransformer): + + def initialize(self, ast): + # We need to rerun the assignment because transformations could have created + # new AST nodes + ParentScopeAssigner().visit(ast) + self.scope_vars = ScopeVarsDeclarations() + self.scope_vars.visit(ast) + + @staticmethod + @abstractmethod + def func_name(self) -> str: + pass + +class DirectReplacement(IntrinsicTransformation): + + Replacement = namedtuple("Replacement", "function") + Transformation = namedtuple("Transformation", "function") + + class ASTTransformation(IntrinsicNodeTransformer): + + def visit_BinOp_Node(self, binop_node: ast_internal_classes.BinOp_Node): + + if not isinstance(binop_node.rval, ast_internal_classes.Call_Expr_Node): + return binop_node + + node = binop_node.rval + + name = node.name.name.split('__dace_') + if len(name) != 2 or name[1] not in DirectReplacement.FUNCTIONS: + return binop_node + func_name = name[1] + + replacement_rule = DirectReplacement.FUNCTIONS[func_name] + if isinstance(replacement_rule, DirectReplacement.Transformation): + + # FIXME: we do not have line number in binop? + binop_node.rval, input_type = replacement_rule.function(node, self.scope_vars, 0) #binop_node.line) + print(binop_node, binop_node.lval, binop_node.rval) + + # replace types of return variable - LHS of the binary operator + var = binop_node.lval + if isinstance(var.name, ast_internal_classes.Name_Node): + name = var.name.name + else: + name = var.name + var_decl = self.scope_vars.get_var(var.parent, name) + var.type = input_type + var_decl.type = input_type + + return binop_node + + + #self.scope_vars.get_var(node.parent, arg.name). + + def replace_size(var: ast_internal_classes.Call_Expr_Node, scope_vars: ScopeVarsDeclarations, line): + + if len(var.args) not in [1, 2]: + raise RuntimeError() + + # get variable declaration for the first argument + var_decl = scope_vars.get_var(var.parent, var.args[0].name) + + # one arg to SIZE? compute the total number of elements + if len(var.args) == 1: + + if len(var_decl.sizes) == 1: + return (var_decl.sizes[0], "INTEGER") + + ret = ast_internal_classes.BinOp_Node( + lval=var_decl.sizes[0], + rval=None, + op="*" + ) + cur_node = ret + for i in range(1, len(var_decl.sizes) - 1): + + cur_node.rval = ast_internal_classes.BinOp_Node( + lval=var_decl.sizes[i], + rval=None, + op="*" + ) + cur_node = cur_node.rval + + cur_node.rval = var_decl.sizes[-1] + return (ret, "INTEGER") + + # two arguments? We return number of elements in a given rank + rank = var.args[1] + # we do not support symbolic argument to DIM - it must be a literal + if not isinstance(rank, ast_internal_classes.Int_Literal_Node): + raise NotImplementedError() + value = int(rank.value) + return (var_decl.sizes[value-1], "INTEGER") + + + def replace_bit_size(var: ast_internal_classes.Call_Expr_Node, scope_vars: ScopeVarsDeclarations, line): + + if len(var.args) != 1: + raise RuntimeError() + + # get variable declaration for the first argument + var_decl = scope_vars.get_var(var.parent, var.args[0].name) + + dace_type = fortrantypes2dacetypes[var_decl.type] + type_size = dace_type().itemsize * 8 + + return (ast_internal_classes.Int_Literal_Node(value=str(type_size)), "INTEGER") + + + def replace_int_kind(args: ast_internal_classes.Arg_List_Node, line): + return ast_internal_classes.Int_Literal_Node(value=str( + math.ceil((math.log2(math.pow(10, int(args.args[0].value))) + 1) / 8)), + line_number=line) + + def replace_real_kind(args: ast_internal_classes.Arg_List_Node, line): + if int(args.args[0].value) >= 9 or int(args.args[1].value) > 126: + return ast_internal_classes.Int_Literal_Node(value="8", line_number=line) + elif int(args.args[0].value) >= 3 or int(args.args[1].value) > 14: + return ast_internal_classes.Int_Literal_Node(value="4", line_number=line) + else: + return ast_internal_classes.Int_Literal_Node(value="2", line_number=line) + FUNCTIONS = { - "SELECTED_INT_KIND": "__dace_selected_int_kind", - "SELECTED_REAL_KIND": "__dace_selected_real_kind", + "SELECTED_INT_KIND": Replacement(replace_int_kind), + "SELECTED_REAL_KIND": Replacement(replace_real_kind), + "BIT_SIZE": Transformation(replace_bit_size), + "SIZE": Transformation(replace_size) } @staticmethod - def replaced_name(func_name: str) -> str: - return SelectedKind.FUNCTIONS[func_name] + def temporary_functions(): + + # temporary functions created by us -> f becomes __dace_f + # We provide this to tell Fortran parser that these are function calls, + # not array accesses + funcs = list(DirectReplacement.FUNCTIONS.keys()) + return [f'__dace_{f}' for f in funcs] + + @staticmethod + def replacable_name(func_name: str) -> bool: + return func_name in DirectReplacement.FUNCTIONS + + @staticmethod + def replace_name(func_name: str) -> str: + #return ast_internal_classes.Name_Node(name=DirectReplacement.FUNCTIONS[func_name][0]) + return ast_internal_classes.Name_Node(name=f'__dace_{func_name}') + + @staticmethod + def replacable(func_name: str) -> bool: + orig_name = func_name.split('__dace_') + if len(orig_name) > 1 and orig_name[1] in DirectReplacement.FUNCTIONS: + return isinstance(DirectReplacement.FUNCTIONS[orig_name[1]], DirectReplacement.Replacement) + return False @staticmethod def replace(func_name: ast_internal_classes.Name_Node, args: ast_internal_classes.Arg_List_Node, line) -> ast_internal_classes.FNode: - if func_name.name == "__dace_selected_int_kind": - return ast_internal_classes.Int_Literal_Node(value=str( - math.ceil((math.log2(math.pow(10, int(args.args[0].value))) + 1) / 8)), - line_number=line) - # This selects the smallest kind that can hold the given number of digits (fp64,fp32 or fp16) - elif func_name.name == "__dace_selected_real_kind": - if int(args.args[0].value) >= 9 or int(args.args[1].value) > 126: - return ast_internal_classes.Int_Literal_Node(value="8", line_number=line) - elif int(args.args[0].value) >= 3 or int(args.args[1].value) > 14: - return ast_internal_classes.Int_Literal_Node(value="4", line_number=line) - else: - return ast_internal_classes.Int_Literal_Node(value="2", line_number=line) + # Here we already have __dace_func + fname = func_name.split('__dace_')[1] + return DirectReplacement.FUNCTIONS[fname].function(args, line) - raise NotImplemented() + def has_transformation(fname: str) -> bool: + return isinstance(DirectReplacement.FUNCTIONS[fname], DirectReplacement.Transformation) + + @staticmethod + def get_transformation() -> IntrinsicNodeTransformer: + return DirectReplacement.ASTTransformation() class LoopBasedReplacement: @@ -84,36 +226,34 @@ class LoopBasedReplacementVisitor(NodeVisitor): def __init__(self, func_name: str): self._func_name = func_name self.nodes: List[ast_internal_classes.FNode] = [] + self.calls: List[ast_internal_classes.FNode] = [] def visit_BinOp_Node(self, node: ast_internal_classes.BinOp_Node): - if isinstance(node.rval, ast_internal_classes.Call_Expr_Node): if node.rval.name.name == self._func_name: self.nodes.append(node) + self.calls.append(node.rval) + self.visit(node.lval) + self.visit(node.rval) + + def visit_Call_Expr_Node(self, node: ast_internal_classes.Call_Expr_Node): + + if node.name.name == self._func_name: + if node not in self.calls: + self.nodes.append(node) def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_Node): return -class LoopBasedReplacementTransformation(NodeTransformer): +class LoopBasedReplacementTransformation(IntrinsicNodeTransformer): """ Transforms the AST by removing intrinsic call and replacing it with loops """ - def __init__(self, ast): + def __init__(self): self.count = 0 - - # We need to rerun the assignment because transformations could have created - # new AST nodes - ParentScopeAssigner().visit(ast) - self.scope_vars = ScopeVarsDeclarations() - self.scope_vars.visit(ast) self.rvals = [] - @staticmethod - @abstractmethod - def func_name() -> str: - pass - @abstractmethod def _initialize(self): pass @@ -338,9 +478,6 @@ def visit_Execution_Part_Node(self, node: ast_internal_classes.Execution_Part_No class SumProduct(LoopBasedReplacementTransformation): - def __init__(self, ast): - super().__init__(ast) - def _initialize(self): self.rvals = [] self.argument_variable = None @@ -414,9 +551,6 @@ class Sum(LoopBasedReplacement): class Transformation(SumProduct): - def __init__(self, ast): - super().__init__(ast) - @staticmethod def func_name() -> str: return "__dace_sum" @@ -440,9 +574,6 @@ class Product(LoopBasedReplacement): class Transformation(SumProduct): - def __init__(self, ast): - super().__init__(ast) - @staticmethod def func_name() -> str: return "__dace_product" @@ -455,9 +586,6 @@ def _result_update_op(self): class AnyAllCountTransformation(LoopBasedReplacementTransformation): - def __init__(self, ast): - super().__init__(ast) - def _initialize(self): self.rvals = [] @@ -575,9 +703,6 @@ class Any(LoopBasedReplacement): """ class Transformation(AnyAllCountTransformation): - def __init__(self, ast): - super().__init__(ast) - def _result_init_value(self): return "0" @@ -607,9 +732,6 @@ class All(LoopBasedReplacement): """ class Transformation(AnyAllCountTransformation): - def __init__(self, ast): - super().__init__(ast) - def _result_init_value(self): return "1" @@ -644,9 +766,6 @@ class Count(LoopBasedReplacement): """ class Transformation(AnyAllCountTransformation): - def __init__(self, ast): - super().__init__(ast) - def _result_init_value(self): return "0" @@ -675,9 +794,6 @@ def func_name() -> str: class MinMaxValTransformation(LoopBasedReplacementTransformation): - def __init__(self, ast): - super().__init__(ast) - def _initialize(self): self.rvals = [] self.argument_variable = None @@ -753,9 +869,6 @@ class MinVal(LoopBasedReplacement): """ class Transformation(MinMaxValTransformation): - def __init__(self, ast): - super().__init__(ast) - def _result_init_value(self, array: ast_internal_classes.Array_Subscript_Node): var_decl = self.scope_vars.get_var(array.parent, array.name.name) @@ -788,9 +901,6 @@ class MaxVal(LoopBasedReplacement): """ class Transformation(MinMaxValTransformation): - def __init__(self, ast): - super().__init__(ast) - def _result_init_value(self, array: ast_internal_classes.Array_Subscript_Node): var_decl = self.scope_vars.get_var(array.parent, array.name.name) @@ -817,9 +927,6 @@ class Merge(LoopBasedReplacement): class Transformation(LoopBasedReplacementTransformation): - def __init__(self, ast): - super().__init__(ast) - def _initialize(self): self.rvals = [] @@ -939,11 +1046,235 @@ def _generate_loop_body(self, node: ast_internal_classes.FNode) -> ast_internal_ line_number=node.line_number ) +class MathFunctions(IntrinsicTransformation): + + MathTransformation = namedtuple("MathTransformation", "function return_type") + MathReplacement = namedtuple("MathReplacement", "function replacement_function return_type") + + def generate_scale(arg: ast_internal_classes.Call_Expr_Node): + + # SCALE(X, I) becomes: X * pow(RADIX(X), I) + # In our case, RADIX(X) is always 2 + line = arg.line_number + x = arg.args[0] + i = arg.args[1] + const_two = ast_internal_classes.Int_Literal_Node(value="2") + + # I and RADIX(X) are both integers + rval = ast_internal_classes.Call_Expr_Node( + name=ast_internal_classes.Name_Node(name="pow"), + type="INTEGER", + args=[const_two, i], + line_number=line + ) + + mult = ast_internal_classes.BinOp_Node( + op="*", + lval=x, + rval=rval, + line_number=line + ) + + # pack it into parentheses, just to be sure + return ast_internal_classes.Parenthesis_Expr_Node(expr=mult) + + def generate_aint(arg: ast_internal_classes.Call_Expr_Node): + + # The call to AINT can contain a second KIND parameter + # We ignore it a the moment. + # However, to map into C's trunc, we need to drop it. + if len(arg.args) > 1: + del arg.args[1] + + fname = arg.name.name.split('__dace_')[1] + if fname in "AINT": + arg.name = ast_internal_classes.Name_Node(name="trunc") + elif fname == "NINT": + arg.name = ast_internal_classes.Name_Node(name="iround") + elif fname == "ANINT": + arg.name = ast_internal_classes.Name_Node(name="round") + else: + raise NotImplementedError() + + return arg + + INTRINSIC_TO_DACE = { + "MIN": MathTransformation("min", "FIRST_ARG"), + "MAX": MathTransformation("max", "FIRST_ARG"), + "SQRT": MathTransformation("sqrt", "FIRST_ARG"), + "ABS": MathTransformation("abs", "FIRST_ARG"), + "EXP": MathTransformation("exp", "FIRST_ARG"), + # Documentation states that the return type of LOG is always REAL, + # but the kind is the same as of the first argument. + # However, we already replaced kind with types used in DaCe. + # Thus, a REAL that is really DOUBLE will be double in the first argument. + "LOG": MathTransformation("log", "FIRST_ARG"), + "MOD": { + "INTEGER": MathTransformation("Mod", "INTEGER"), + "REAL": MathTransformation("Mod_float", "REAL"), + "DOUBLE": MathTransformation("Mod_float", "DOUBLE") + }, + "MODULO": { + "INTEGER": MathTransformation("Modulo", "INTEGER"), + "REAL": MathTransformation("Modulo_float", "REAL"), + "DOUBLE": MathTransformation("Modulo_float", "DOUBLE") + }, + "FLOOR": { + "REAL": MathTransformation("floor", "INTEGER"), + "DOUBLE": MathTransformation("floor", "INTEGER") + }, + "SCALE": MathReplacement(None, generate_scale, "FIRST_ARG"), + "EXPONENT": MathTransformation("frexp", "INTEGER"), + "INT": MathTransformation("int", "INTEGER"), + "AINT": MathReplacement("trunc", generate_aint, "FIRST_ARG"), + "NINT": MathReplacement("iround", generate_aint, "INTEGER"), + "ANINT": MathReplacement("round", generate_aint, "FIRST_ARG"), + "REAL": MathTransformation("float", "REAL"), + "DBLE": MathTransformation("double", "DOUBLE"), + "SIN": MathTransformation("sin", "FIRST_ARG"), + "COS": MathTransformation("cos", "FIRST_ARG"), + "SINH": MathTransformation("sinh", "FIRST_ARG"), + "COSH": MathTransformation("cosh", "FIRST_ARG"), + "TANH": MathTransformation("tanh", "FIRST_ARG"), + "ASIN": MathTransformation("asin", "FIRST_ARG"), + "ACOS": MathTransformation("acos", "FIRST_ARG"), + "ATAN": MathTransformation("atan", "FIRST_ARG"), + "ATAN2": MathTransformation("atan2", "FIRST_ARG") + } + + class TypeTransformer(IntrinsicNodeTransformer): + + def func_type(self, node: ast_internal_classes.Call_Expr_Node): + + # take the first arg + arg = node.args[0] + if isinstance(arg, ast_internal_classes.Real_Literal_Node): + return 'REAL' + elif isinstance(arg, ast_internal_classes.Int_Literal_Node): + return 'INTEGER' + elif isinstance(arg, ast_internal_classes.Call_Expr_Node): + return arg.type + elif isinstance(arg, ast_internal_classes.Name_Node): + input_type = self.scope_vars.get_var(node.parent, arg.name) + return input_type.type + else: + input_type = self.scope_vars.get_var(node.parent, arg.name.name) + return input_type.type + + def replace_call(self, old_call: ast_internal_classes.Call_Expr_Node, new_call: ast_internal_classes.FNode): + + parent = old_call.parent + + # We won't need it if the CallExtractor will properly support nested function calls. + # Then, all function calls should be a binary op: val = func() + if isinstance(parent, ast_internal_classes.BinOp_Node): + if parent.lval == old_call: + parent.lval = new_call + else: + parent.rval = new_call + elif isinstance(parent, ast_internal_classes.UnOp_Node): + parent.lval = new_call + elif isinstance(parent, ast_internal_classes.Parenthesis_Expr_Node): + parent.expr = new_call + elif isinstance(parent, ast_internal_classes.Call_Expr_Node): + for idx, arg in enumerate(parent.args): + if arg == old_call: + parent.args[idx] = new_call + break + else: + raise NotImplementedError() + + def visit_BinOp_Node(self, binop_node: ast_internal_classes.BinOp_Node): + + if not isinstance(binop_node.rval, ast_internal_classes.Call_Expr_Node): + return binop_node + + node = binop_node.rval + + name = node.name.name.split('__dace_') + if len(name) != 2 or name[1] not in MathFunctions.INTRINSIC_TO_DACE: + return binop_node + func_name = name[1] + + # Visit all children before we expand this call. + # We need that to properly get the type. + for arg in node.args: + self.visit(arg) + + return_type = None + input_type = None + input_type = self.func_type(node) + + replacement_rule = MathFunctions.INTRINSIC_TO_DACE[func_name] + if isinstance(replacement_rule, dict): + replacement_rule = replacement_rule[input_type] + if replacement_rule.return_type == "FIRST_ARG": + return_type = input_type + else: + return_type = replacement_rule.return_type + + if isinstance(replacement_rule, MathFunctions.MathTransformation): + node.name = ast_internal_classes.Name_Node(name=replacement_rule.function) + node.type = return_type + + else: + binop_node.rval = replacement_rule.replacement_function(node) + + # replace types of return variable - LHS of the binary operator + var = binop_node.lval + name = None + if isinstance(var.name, ast_internal_classes.Name_Node): + name = var.name.name + else: + name = var.name + var_decl = self.scope_vars.get_var(var.parent, name) + var.type = input_type + var_decl.type = input_type + + return binop_node + + @staticmethod + def dace_functions(): + + # list of final dace functions which we create + funcs = list(MathFunctions.INTRINSIC_TO_DACE.values()) + res = [] + # flatten nested lists + for f in funcs: + if isinstance(f, dict): + res.extend([v.function for k, v in f.items() if v.function is not None]) + else: + if f.function is not None: + res.append(f.function) + return res + + @staticmethod + def temporary_functions(): + + # temporary functions created by us -> f becomes __dace_f + # We provide this to tell Fortran parser that these are function calls, + # not array accesses + funcs = list(MathFunctions.INTRINSIC_TO_DACE.keys()) + return [f'__dace_{f}' for f in funcs] + + @staticmethod + def replacable(func_name: str) -> bool: + return func_name in MathFunctions.INTRINSIC_TO_DACE + + @staticmethod + def replace(func_name: str) -> ast_internal_classes.FNode: + return ast_internal_classes.Name_Node(name=f'__dace_{func_name}') + + def has_transformation() -> bool: + return True + + @staticmethod + def get_transformation() -> TypeTransformer: + return MathFunctions.TypeTransformer() + class FortranIntrinsics: IMPLEMENTATIONS_AST = { - "SELECTED_INT_KIND": SelectedKind, - "SELECTED_REAL_KIND": SelectedKind, "SUM": Sum, "PRODUCT": Product, "ANY": Any, @@ -954,11 +1285,6 @@ class FortranIntrinsics: "MERGE": Merge } - DIRECT_REPLACEMENTS = { - "__dace_selected_int_kind": SelectedKind, - "__dace_selected_real_kind": SelectedKind - } - EXEMPTED_FROM_CALL_EXTRACTION = [ Merge ] @@ -971,59 +1297,58 @@ def transformations(self) -> Set[Type[NodeTransformer]]: @staticmethod def function_names() -> List[str]: - return list(LoopBasedReplacement.INTRINSIC_TO_DACE.values()) + # list of all functions that are created by initial transformation, before doing full replacement + # this prevents other parser components from replacing our function calls with array subscription nodes + return [*list(LoopBasedReplacement.INTRINSIC_TO_DACE.values()), *MathFunctions.temporary_functions(), *DirectReplacement.temporary_functions()] + + @staticmethod + def retained_function_names() -> List[str]: + # list of all DaCe functions that we use after full parsing + return MathFunctions.dace_functions() @staticmethod def call_extraction_exemptions() -> List[str]: - return [func.Transformation.func_name() for func in FortranIntrinsics.EXEMPTED_FROM_CALL_EXTRACTION] + return [ + *[func.Transformation.func_name() for func in FortranIntrinsics.EXEMPTED_FROM_CALL_EXTRACTION] + #*MathFunctions.temporary_functions() + ] def replace_function_name(self, node: FASTNode) -> ast_internal_classes.Name_Node: func_name = node.string replacements = { - "INT": "__dace_int", - "DBLE": "__dace_dble", - "SQRT": "sqrt", - "COSH": "cosh", - "ABS": "abs", - "MIN": "min", - "MAX": "max", - "EXP": "exp", - "EPSILON": "__dace_epsilon", - "TANH": "tanh", "SIGN": "__dace_sign", - "EXP": "exp" } if func_name in replacements: return ast_internal_classes.Name_Node(name=replacements[func_name]) - else: - - if self.IMPLEMENTATIONS_AST[func_name].has_transformation(): - self._transformations_to_run.add(self.IMPLEMENTATIONS_AST[func_name].Transformation) + elif DirectReplacement.replacable_name(func_name): + if DirectReplacement.has_transformation(func_name): + self._transformations_to_run.add(DirectReplacement.get_transformation()) + return DirectReplacement.replace_name(func_name) + elif MathFunctions.replacable(func_name): + self._transformations_to_run.add(MathFunctions.get_transformation()) + return MathFunctions.replace(func_name) + + if self.IMPLEMENTATIONS_AST[func_name].has_transformation(): + + if hasattr(self.IMPLEMENTATIONS_AST[func_name], "Transformation"): + self._transformations_to_run.add(self.IMPLEMENTATIONS_AST[func_name].Transformation()) + else: + self._transformations_to_run.add(self.IMPLEMENTATIONS_AST[func_name].get_transformation(func_name)) - return ast_internal_classes.Name_Node(name=self.IMPLEMENTATIONS_AST[func_name].replaced_name(func_name)) + return ast_internal_classes.Name_Node(name=self.IMPLEMENTATIONS_AST[func_name].replaced_name(func_name)) def replace_function_reference(self, name: ast_internal_classes.Name_Node, args: ast_internal_classes.Arg_List_Node, line): func_types = { - "__dace_int": "INT", - "__dace_dble": "DOUBLE", - "sqrt": "DOUBLE", - "cosh": "DOUBLE", - "abs": "DOUBLE", - "min": "DOUBLE", - "max": "DOUBLE", - "exp": "DOUBLE", - "__dace_epsilon": "DOUBLE", - "tanh": "DOUBLE", "__dace_sign": "DOUBLE", } if name.name in func_types: # FIXME: this will be progressively removed call_type = func_types[name.name] return ast_internal_classes.Call_Expr_Node(name=name, type=call_type, args=args.args, line_number=line) - elif name.name in self.DIRECT_REPLACEMENTS: - return self.DIRECT_REPLACEMENTS[name.name].replace(name, args, line) + elif DirectReplacement.replacable(name.name): + return DirectReplacement.replace(name.name, args, line) else: # We will do the actual type replacement later # To that end, we need to know the input types - but these we do not know at the moment. diff --git a/dace/runtime/include/dace/math.h b/dace/runtime/include/dace/math.h index 0a9d153767..4dae494a8a 100644 --- a/dace/runtime/include/dace/math.h +++ b/dace/runtime/include/dace/math.h @@ -61,6 +61,45 @@ static DACE_CONSTEXPR DACE_HDFI T Mod(const T& value, const T2& modulus) { return value % modulus; } +// Fortran implements MOD for floating-point values as well +template +static DACE_CONSTEXPR DACE_HDFI T Mod_float(const T& value, const T& modulus) { + return value - static_cast(value / modulus) * modulus; +} + +// Fortran implementation of MODULO +template +static DACE_CONSTEXPR DACE_HDFI T Modulo(const T& value, const T& modulus) { + // Fortran implementation for integers - find R such that value = Q * modulus + R + // However, R must be in [0, modulus) + // To achieve that, we need to cast the division to floats. + // Example: -17, 3 must produce 1 and not -2. + // If we don't use cast, the floor is called on -5, producing wrong value. + // Instead, we need to have floor(-5.6... ) to ensure it produces -6. + // Similarly, 17, -3 must produce -1 and not 2. + // This means that the default solution works if value and modulus have the same sign. + return value - floor(static_cast(value) / modulus) * modulus; +} + +template +static DACE_CONSTEXPR DACE_HDFI T Modulo_float(const T& value, const T& modulus) { + return value - floor(value / modulus) * modulus; +} + +// Implement to support a match wtih Fortran's intrinsic EXPONENT +template::value>* = nullptr> +static DACE_CONSTEXPR DACE_HDFI int frexp(const T& a) { + int exponent = 0; + std::frexp(a, &exponent); + return exponent; +} + +// Implement to support Fortran's intrinsic NINT - round, but return an integer +template::value>* = nullptr> +static DACE_CONSTEXPR DACE_HDFI int iround(const T& a) { + return static_cast(round(a)); +} + template static DACE_CONSTEXPR DACE_HDFI T int_ceil(const T& numerator, const T2& denominator) { return (numerator + denominator - 1) / denominator; diff --git a/tests/fortran/call_extract_test.py b/tests/fortran/call_extract_test.py new file mode 100644 index 0000000000..eb1f2ac86d --- /dev/null +++ b/tests/fortran/call_extract_test.py @@ -0,0 +1,39 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. + +import numpy as np +import pytest + +from dace.frontend.fortran import fortran_parser + +def test_fortran_frontend_call_extract(): + test_string = """ + PROGRAM intrinsic_call_extract + implicit none + real, dimension(2) :: d + real, dimension(2) :: res + CALL intrinsic_call_extract_test_function(d,res) + end + + SUBROUTINE intrinsic_call_extract_test_function(d,res) + real, dimension(2) :: d + real, dimension(2) :: res + + res(1) = SQRT(SIGN(EXP(d(1)), LOG(d(1)))) + res(2) = MIN(SQRT(EXP(d(1))), SQRT(EXP(d(1))) - 1) + + END SUBROUTINE intrinsic_call_extract_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_call_extract", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + input = np.full([2], 42, order="F", dtype=np.float32) + res = np.full([2], 42, order="F", dtype=np.float32) + sdfg(d=input, res=res) + assert np.allclose(res, [np.sqrt(np.exp(input[0])), np.sqrt(np.exp(input[0])) - 1]) + + +if __name__ == "__main__": + + test_fortran_frontend_call_extract() diff --git a/tests/fortran/intrinsic_basic_test.py b/tests/fortran/intrinsic_basic_test.py new file mode 100644 index 0000000000..9ef31dd108 --- /dev/null +++ b/tests/fortran/intrinsic_basic_test.py @@ -0,0 +1,98 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. + +import numpy as np +import pytest + +from dace.frontend.fortran import fortran_parser + +def test_fortran_frontend_bit_size(): + test_string = """ + PROGRAM intrinsic_math_test_bit_size + implicit none + integer, dimension(4) :: res + CALL intrinsic_math_test_function(res) + end + + SUBROUTINE intrinsic_math_test_function(res) + integer, dimension(4) :: res + logical :: a = .TRUE. + integer :: b = 1 + real :: c = 1 + double precision :: d = 1 + + res(1) = BIT_SIZE(a) + res(2) = BIT_SIZE(b) + res(3) = BIT_SIZE(c) + res(4) = BIT_SIZE(d) + + END SUBROUTINE intrinsic_math_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_bit_size", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 4 + res = np.full([size], 42, order="F", dtype=np.int32) + sdfg(res=res) + + assert np.allclose(res, [32, 32, 32, 64]) + +def test_fortran_frontend_bit_size_symbolic(): + test_string = """ + PROGRAM intrinsic_math_test_bit_size + implicit none + integer, parameter :: arrsize = 2 + integer, parameter :: arrsize2 = 3 + integer, parameter :: arrsize3 = 4 + integer :: res(arrsize) + integer :: res2(arrsize, arrsize2, arrsize3) + integer :: res3(arrsize+arrsize2, arrsize2 * 5, arrsize3 + arrsize2*arrsize) + CALL intrinsic_math_test_function(arrsize, arrsize2, arrsize3, res, res2, res3) + end + + SUBROUTINE intrinsic_math_test_function(arrsize, arrsize2, arrsize3, res, res2, res3) + implicit none + integer :: arrsize + integer :: arrsize2 + integer :: arrsize3 + integer :: res(arrsize) + integer :: res2(arrsize, arrsize2, arrsize3) + integer :: res3(arrsize+arrsize2, arrsize2 * 5, arrsize3 + arrsize2*arrsize) + + res(1) = SIZE(res) + res(2) = SIZE(res2) + res(3) = SIZE(res3) + res(4) = SIZE(res)*2 + res(5) = SIZE(res)*SIZE(res2)*SIZE(res3) + res(6) = SIZE(res2, 1) + SIZE(res2, 2) + SIZE(res2, 3) + res(7) = SIZE(res3, 1) + SIZE(res3, 2) + SIZE(res3, 3) + + END SUBROUTINE intrinsic_math_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_bit_size", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 24 + size2 = 5 + size3 = 7 + res = np.full([size], 42, order="F", dtype=np.int32) + res2 = np.full([size, size2, size3], 42, order="F", dtype=np.int32) + res3 = np.full([size+size2, size2*5, size3 + size*size2], 42, order="F", dtype=np.int32) + sdfg(res=res, res2=res2, res3=res3, arrsize=size, arrsize2=size2, arrsize3=size3) + print(res) + + assert res[0] == size + assert res[1] == size*size2*size3 + assert res[2] == (size + size2) * (size2 * 5) * (size3 + size2*size) + assert res[3] == size * 2 + assert res[4] == res[0] * res[1] * res[2] + assert res[5] == size + size2 + size3 + assert res[6] == size + size2 + size2*5 + size3 + size*size2 + + +if __name__ == "__main__": + test_fortran_frontend_bit_size() + test_fortran_frontend_bit_size_symbolic() diff --git a/tests/fortran/intrinsic_math_test.py b/tests/fortran/intrinsic_math_test.py new file mode 100644 index 0000000000..e1fc469beb --- /dev/null +++ b/tests/fortran/intrinsic_math_test.py @@ -0,0 +1,845 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. + +import numpy as np +import pytest + +from dace.frontend.fortran import fortran_parser + +def test_fortran_frontend_min_max(): + test_string = """ + PROGRAM intrinsic_math_test_min_max + implicit none + double precision, dimension(2) :: arg1 + double precision, dimension(2) :: arg2 + double precision, dimension(2) :: res1 + double precision, dimension(2) :: res2 + CALL intrinsic_math_test_function(arg1, arg2, res1, res2) + end + + SUBROUTINE intrinsic_math_test_function(arg1, arg2, res1, res2) + double precision, dimension(2) :: arg1 + double precision, dimension(2) :: arg2 + double precision, dimension(2) :: res1 + double precision, dimension(2) :: res2 + + res1(1) = MIN(arg1(1), arg2(1)) + res1(2) = MIN(arg1(2), arg2(2)) + + res2(1) = MAX(arg1(1), arg2(1)) + res2(2) = MAX(arg1(2), arg2(2)) + + END SUBROUTINE intrinsic_math_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_min_max", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 2 + arg1 = np.full([size], 42, order="F", dtype=np.float64) + arg2 = np.full([size], 42, order="F", dtype=np.float64) + + arg1[0] = 20 + arg1[1] = 25 + arg2[0] = 30 + arg2[1] = 18 + + res1 = np.full([2], 42, order="F", dtype=np.float64) + res2 = np.full([2], 42, order="F", dtype=np.float64) + sdfg(arg1=arg1, arg2=arg2, res1=res1, res2=res2) + + assert res1[0] == 20 + assert res1[1] == 18 + assert res2[0] == 30 + assert res2[1] == 25 + + +def test_fortran_frontend_sqrt(): + test_string = """ + PROGRAM intrinsic_math_test_sqrt + implicit none + double precision, dimension(2) :: d + double precision, dimension(2) :: res + CALL intrinsic_math_test_function(d, res) + end + + SUBROUTINE intrinsic_math_test_function(d, res) + double precision, dimension(2) :: d + double precision, dimension(2) :: res + + res(1) = SQRT(d(1)) + res(2) = SQRT(d(2)) + + END SUBROUTINE intrinsic_math_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_sqrt", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 2 + d = np.full([size], 42, order="F", dtype=np.float64) + d[0] = 2 + d[1] = 5 + res = np.full([2], 42, order="F", dtype=np.float64) + sdfg(d=d, res=res) + py_res = np.sqrt(d) + + for f_res, p_res in zip(res, py_res): + assert abs(f_res - p_res) < 10**-9 + +def test_fortran_frontend_abs(): + test_string = """ + PROGRAM intrinsic_math_test_abs + implicit none + double precision, dimension(2) :: d + double precision, dimension(2) :: res + CALL intrinsic_math_test_function(d, res) + end + + SUBROUTINE intrinsic_math_test_function(d, res) + double precision, dimension(2) :: d + double precision, dimension(2) :: res + + res(1) = ABS(d(1)) + res(2) = ABS(d(2)) + + END SUBROUTINE intrinsic_math_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_abs", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 2 + d = np.full([size], 42, order="F", dtype=np.float64) + d[0] = -30 + d[1] = 40 + res = np.full([2], 42, order="F", dtype=np.float64) + sdfg(d=d, res=res) + + assert res[0] == 30 + assert res[1] == 40 + +def test_fortran_frontend_exp(): + test_string = """ + PROGRAM intrinsic_math_test_exp + implicit none + double precision, dimension(2) :: d + double precision, dimension(2) :: res + CALL intrinsic_math_test_function(d, res) + end + + SUBROUTINE intrinsic_math_test_function(d, res) + double precision, dimension(2) :: d + double precision, dimension(2) :: res + + res(1) = EXP(d(1)) + res(2) = EXP(d(2)) + + END SUBROUTINE intrinsic_math_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_exp", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 2 + d = np.full([size], 42, order="F", dtype=np.float64) + d[0] = 2 + d[1] = 4.5 + res = np.full([2], 42, order="F", dtype=np.float64) + sdfg(d=d, res=res) + py_res = np.exp(d) + + for f_res, p_res in zip(res, py_res): + assert abs(f_res - p_res) < 10**-9 + +def test_fortran_frontend_log(): + test_string = """ + PROGRAM intrinsic_math_test_log + implicit none + double precision, dimension(2) :: d + double precision, dimension(2) :: res + CALL intrinsic_math_test_function(d, res) + end + + SUBROUTINE intrinsic_math_test_function(d, res) + double precision, dimension(2) :: d + double precision, dimension(2) :: res + + res(1) = LOG(d(1)) + res(2) = LOG(d(2)) + + END SUBROUTINE intrinsic_math_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_exp", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 2 + d = np.full([size], 42, order="F", dtype=np.float64) + d[0] = 2.71 + d[1] = 4.5 + res = np.full([2], 42, order="F", dtype=np.float64) + sdfg(d=d, res=res) + py_res = np.log(d) + + for f_res, p_res in zip(res, py_res): + assert abs(f_res - p_res) < 10**-9 + +def test_fortran_frontend_log(): + test_string = """ + PROGRAM intrinsic_math_test_log + implicit none + double precision, dimension(2) :: d + double precision, dimension(2) :: res + CALL intrinsic_math_test_function(d, res) + end + + SUBROUTINE intrinsic_math_test_function(d, res) + double precision, dimension(2) :: d + double precision, dimension(2) :: res + + res(1) = LOG(d(1)) + res(2) = LOG(d(2)) + + END SUBROUTINE intrinsic_math_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_exp", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 2 + d = np.full([size], 42, order="F", dtype=np.float64) + d[0] = 2.71 + d[1] = 4.5 + res = np.full([2], 42, order="F", dtype=np.float64) + sdfg(d=d, res=res) + py_res = np.log(d) + + for f_res, p_res in zip(res, py_res): + assert abs(f_res - p_res) < 10**-9 + +def test_fortran_frontend_mod_float(): + test_string = """ + PROGRAM intrinsic_math_test_mod + implicit none + double precision, dimension(12) :: d + double precision, dimension(6) :: res + CALL intrinsic_math_test_function(d, res) + end + + SUBROUTINE intrinsic_math_test_function(d, res) + double precision, dimension(12) :: d + double precision, dimension(6) :: res + + res(1) = MOD(d(1), d(2)) + res(2) = MOD(d(3), d(4)) + res(3) = MOD(d(5), d(6)) + res(4) = MOD(d(7), d(8)) + res(5) = MOD(d(9), d(10)) + res(6) = MOD(d(11), d(12)) + + END SUBROUTINE intrinsic_math_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_mod", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 12 + d = np.full([size], 42, order="F", dtype=np.float64) + d[0] = 17. + d[1] = 3. + d[2] = -17. + d[3] = 3. + d[4] = 17. + d[5] = -3. + d[6] = -17. + d[7] = -3. + d[8] = 17.5 + d[9] = 5.5 + d[10] = -17.5 + d[11] = 5.5 + res = np.full([6], 42, order="F", dtype=np.float64) + sdfg(d=d, res=res) + + assert res[0] == 2.0 + assert res[1] == -2.0 + assert res[2] == 2.0 + assert res[3] == -2.0 + assert res[4] == 1 + assert res[5] == -1 + +def test_fortran_frontend_mod_integer(): + test_string = """ + PROGRAM intrinsic_math_test_mod + implicit none + integer, dimension(8) :: d + integer, dimension(4) :: res + CALL intrinsic_math_test_function(d, res) + end + + SUBROUTINE intrinsic_math_test_function(d, res) + integer, dimension(8) :: d + integer, dimension(4) :: res + + res(1) = MOD(d(1), d(2)) + res(2) = MOD(d(3), d(4)) + res(3) = MOD(d(5), d(6)) + res(4) = MOD(d(7), d(8)) + + END SUBROUTINE intrinsic_math_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 12 + d = np.full([size], 42, order="F", dtype=np.int32) + d[0] = 17 + d[1] = 3 + d[2] = -17 + d[3] = 3 + d[4] = 17 + d[5] = -3 + d[6] = -17 + d[7] = -3 + res = np.full([4], 42, order="F", dtype=np.int32) + sdfg(d=d, res=res) + assert res[0] == 2 + assert res[1] == -2 + assert res[2] == 2 + assert res[3] == -2 + +def test_fortran_frontend_modulo_float(): + test_string = """ + PROGRAM intrinsic_math_test_modulo + implicit none + double precision, dimension(12) :: d + double precision, dimension(6) :: res + CALL intrinsic_math_test_function(d, res) + end + + SUBROUTINE intrinsic_math_test_function(d, res) + double precision, dimension(12) :: d + double precision, dimension(6) :: res + + res(1) = MODULO(d(1), d(2)) + res(2) = MODULO(d(3), d(4)) + res(3) = MODULO(d(5), d(6)) + res(4) = MODULO(d(7), d(8)) + res(5) = MODULO(d(9), d(10)) + res(6) = MODULO(d(11), d(12)) + + END SUBROUTINE intrinsic_math_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 12 + d = np.full([size], 42, order="F", dtype=np.float64) + d[0] = 17. + d[1] = 3. + d[2] = -17. + d[3] = 3. + d[4] = 17. + d[5] = -3. + d[6] = -17. + d[7] = -3. + d[8] = 17.5 + d[9] = 5.5 + d[10] = -17.5 + d[11] = 5.5 + res = np.full([6], 42, order="F", dtype=np.float64) + sdfg(d=d, res=res) + + assert res[0] == 2.0 + assert res[1] == 1.0 + assert res[2] == -1.0 + assert res[3] == -2.0 + assert res[4] == 1.0 + assert res[5] == 4.5 + +def test_fortran_frontend_modulo_integer(): + test_string = """ + PROGRAM intrinsic_math_test_modulo + implicit none + integer, dimension(8) :: d + integer, dimension(4) :: res + CALL intrinsic_math_test_function(d, res) + end + + SUBROUTINE intrinsic_math_test_function(d, res) + integer, dimension(8) :: d + integer, dimension(4) :: res + + res(1) = MODULO(d(1), d(2)) + res(2) = MODULO(d(3), d(4)) + res(3) = MODULO(d(5), d(6)) + res(4) = MODULO(d(7), d(8)) + + END SUBROUTINE intrinsic_math_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 12 + d = np.full([size], 42, order="F", dtype=np.int32) + d[0] = 17 + d[1] = 3 + d[2] = -17 + d[3] = 3 + d[4] = 17 + d[5] = -3 + d[6] = -17 + d[7] = -3 + res = np.full([4], 42, order="F", dtype=np.int32) + sdfg(d=d, res=res) + + assert res[0] == 2 + assert res[1] == 1 + assert res[2] == -1 + assert res[3] == -2 + +def test_fortran_frontend_floor(): + test_string = """ + PROGRAM intrinsic_math_test_floor + implicit none + real, dimension(4) :: d + integer, dimension(4) :: res + CALL intrinsic_math_test_function(d, res) + end + + SUBROUTINE intrinsic_math_test_function(d, res) + real, dimension(4) :: d + integer, dimension(4) :: res + + res(1) = FLOOR(d(1)) + res(2) = FLOOR(d(2)) + res(3) = FLOOR(d(3)) + res(4) = FLOOR(d(4)) + + END SUBROUTINE intrinsic_math_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 4 + d = np.full([size], 42, order="F", dtype=np.float32) + d[0] = 3.5 + d[1] = 63.000001 + d[2] = -3.5 + d[3] = -63.00001 + res = np.full([4], 42, order="F", dtype=np.int32) + sdfg(d=d, res=res) + + assert res[0] == 3 + assert res[1] == 63 + assert res[2] == -4 + assert res[3] == -64 + +def test_fortran_frontend_scale(): + test_string = """ + PROGRAM intrinsic_math_test_scale + implicit none + real, dimension(4) :: d + integer, dimension(4) :: d2 + real, dimension(5) :: res + CALL intrinsic_math_test_function(d, d2, res) + end + + SUBROUTINE intrinsic_math_test_function(d, d2, res) + real, dimension(4) :: d + integer, dimension(4) :: d2 + real, dimension(5) :: res + + res(1) = SCALE(d(1), d2(1)) + res(2) = SCALE(d(2), d2(2)) + res(3) = SCALE(d(3), d2(3)) + ! Verifies that we properly replace call even inside a complex expression + res(4) = (SCALE(d(4), d2(4))) + (SCALE(d(4), d2(4))*2) + res(5) = (SCALE(SCALE(d(4), d2(4)), d2(4))) + + END SUBROUTINE intrinsic_math_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 4 + d = np.full([size], 42, order="F", dtype=np.float32) + d[0] = 178.1387e-4 + d[1] = 5.5 + d[2] = 5.5 + d[3] = 42.5 + d2 = np.full([size], 42, order="F", dtype=np.int32) + d2[0] = 5 + d2[1] = 5 + d2[2] = 7 + d2[3] = 9 + res = np.full([5], 42, order="F", dtype=np.float32) + sdfg(d=d, d2=d2, res=res) + + assert abs(res[0] - 0.570043862) < 10**-7 + assert res[1] == 176. + assert res[2] == 704. + assert res[3] == 65280. + assert res[4] == 11141120. + +def test_fortran_frontend_exponent(): + test_string = """ + PROGRAM intrinsic_math_test_exponent + implicit none + real, dimension(4) :: d + integer, dimension(4) :: res + CALL intrinsic_math_test_function(d, res) + end + + SUBROUTINE intrinsic_math_test_function(d, res) + real, dimension(4) :: d + integer, dimension(4) :: res + + res(1) = EXPONENT(d(1)) + res(2) = EXPONENT(d(2)) + res(3) = EXPONENT(d(3)) + res(4) = EXPONENT(d(4)) + + END SUBROUTINE intrinsic_math_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 4 + d = np.full([size], 42, order="F", dtype=np.float32) + d[0] = 0.0 + d[1] = 1.0 + d[2] = 13.0 + d[3] = 390.0 + res = np.full([5], 42, order="F", dtype=np.int32) + sdfg(d=d, res=res) + + assert res[0] == 0 + assert res[1] == 1 + assert res[2] == 4 + assert res[3] == 9 + +def test_fortran_frontend_int(): + test_string = """ + PROGRAM intrinsic_math_test_int + implicit none + real, dimension(4) :: d + real, dimension(8) :: d2 + integer, dimension(4) :: res + real, dimension(4) :: res2 + integer, dimension(8) :: res3 + real, dimension(8) :: res4 + CALL intrinsic_math_test_function(d, d2, res, res2, res3, res4) + end + + SUBROUTINE intrinsic_math_test_function(d, d2, res, res2, res3, res4) + integer :: n + real, dimension(4) :: d + real, dimension(8) :: d2 + integer, dimension(4) :: res + real, dimension(4) :: res2 + integer, dimension(8) :: res3 + real, dimension(8) :: res4 + + res(1) = INT(d(1)) + res(2) = INT(d(2)) + res(3) = INT(d(3)) + res(4) = INT(d(4)) + + res2(1) = AINT(d(1)) + res2(2) = AINT(d(2)) + res2(3) = AINT(d(3)) + ! KIND parameter is ignored + res2(4) = AINT(d(4), 4) + + DO n=1,8 + ! KIND parameter is ignored + res3(n) = NINT(d2(n), 4) + END DO + + DO n=1,8 + ! KIND parameter is ignored + res4(n) = ANINT(d2(n), 4) + END DO + + END SUBROUTINE intrinsic_math_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 4 + d = np.full([size], 42, order="F", dtype=np.float32) + d[0] = 1.0 + d[1] = 1.5 + d[2] = 42.5 + d[3] = -42.5 + d2 = np.full([size*2], 42, order="F", dtype=np.float32) + d2[0] = 3.49 + d2[1] = 3.5 + d2[2] = 3.51 + d2[3] = 4 + d2[4] = -3.49 + d2[5] = -3.5 + d2[6] = -3.51 + d2[7] = -4 + res = np.full([4], 42, order="F", dtype=np.int32) + res2 = np.full([4], 42, order="F", dtype=np.float32) + res3 = np.full([8], 42, order="F", dtype=np.int32) + res4 = np.full([8], 42, order="F", dtype=np.float32) + sdfg(d=d, d2=d2, res=res, res2=res2, res3=res3, res4=res4) + + assert np.array_equal(res, [1, 1, 42, -42]) + + assert np.array_equal(res2, [1., 1., 42., -42.]) + + assert np.array_equal(res3, [3, 4, 4, 4, -3, -4, -4, -4]) + + assert np.array_equal(res4, [3., 4., 4., 4., -3., -4., -4., -4.]) + +def test_fortran_frontend_real(): + test_string = """ + PROGRAM intrinsic_math_test_real + implicit none + double precision, dimension(2) :: d + real, dimension(2) :: d2 + integer, dimension(2) :: d3 + double precision, dimension(6) :: res + real, dimension(6) :: res2 + CALL intrinsic_math_test_function(d, d2, d3, res, res2) + end + + SUBROUTINE intrinsic_math_test_function(d, d2, d3, res, res2) + integer :: n + double precision, dimension(2) :: d + real, dimension(2) :: d2 + integer, dimension(2) :: d3 + double precision, dimension(6) :: res + real, dimension(6) :: res2 + + res(1) = DBLE(d(1)) + res(2) = DBLE(d(2)) + res(3) = DBLE(d2(1)) + res(4) = DBLE(d2(2)) + res(5) = DBLE(d3(1)) + res(6) = DBLE(d3(2)) + + res2(1) = REAL(d(1)) + res2(2) = REAL(d(2)) + res2(3) = REAL(d2(1)) + res2(4) = REAL(d2(2)) + res2(5) = REAL(d3(1)) + res2(6) = REAL(d3(2)) + + END SUBROUTINE intrinsic_math_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 2 + d = np.full([size], 42, order="F", dtype=np.float64) + d[0] = 7.0 + d[1] = 13.11 + d2 = np.full([size], 42, order="F", dtype=np.float32) + d2[0] = 7.0 + d2[1] = 13.11 + d3 = np.full([size], 42, order="F", dtype=np.int32) + d3[0] = 7 + d3[1] = 13 + + res = np.full([size*3], 42, order="F", dtype=np.float64) + res2 = np.full([size*3], 42, order="F", dtype=np.float32) + sdfg(d=d, d2=d2, d3=d3, res=res, res2=res2) + + assert np.allclose(res, [7.0, 13.11, 7.0, 13.11, 7., 13.]) + assert np.allclose(res2, [7.0, 13.11, 7.0, 13.11, 7., 13.]) + +def test_fortran_frontend_trig(): + test_string = """ + PROGRAM intrinsic_math_test_trig + implicit none + real, dimension(3) :: d + real, dimension(6) :: res + CALL intrinsic_math_test_function(d, res) + end + + SUBROUTINE intrinsic_math_test_function(d, res) + integer :: n + real, dimension(3) :: d + real, dimension(6) :: res + + DO n=1,3 + res(n) = SIN(d(n)) + END DO + + DO n=1,3 + res(n+3) = COS(d(n)) + END DO + + END SUBROUTINE intrinsic_math_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 3 + d = np.full([size], 42, order="F", dtype=np.float32) + d[0] = 0 + d[1] = 3.14/2 + d[2] = 3.14 + + res = np.full([size*2], 42, order="F", dtype=np.float32) + sdfg(d=d, res=res) + + assert np.allclose(res, [0.0, 0.999999702, 1.59254798E-03, 1.0, 7.96274282E-04, -0.999998748]) + +def test_fortran_frontend_hyperbolic(): + test_string = """ + PROGRAM intrinsic_math_test_hyperbolic + implicit none + real, dimension(3) :: d + real, dimension(9) :: res + CALL intrinsic_math_test_function(d, res) + end + + SUBROUTINE intrinsic_math_test_function(d, res) + integer :: n + real, dimension(3) :: d + real, dimension(9) :: res + + DO n=1,3 + res(n) = SINH(d(n)) + END DO + + DO n=1,3 + res(n+3) = COSH(d(n)) + END DO + + DO n=1,3 + res(n+6) = TANH(d(n)) + END DO + + END SUBROUTINE intrinsic_math_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 3 + d = np.full([size], 42, order="F", dtype=np.float32) + d[0] = 0 + d[1] = 1 + d[2] = 3.14 + + res = np.full([size*3], 42, order="F", dtype=np.float32) + sdfg(d=d, res=res) + + assert np.allclose(res, [0.00000000, 1.17520118, 11.5302935, 1.00000000, 1.54308057, 11.5735760, 0.00000000, 0.761594176, 0.996260226]) + +def test_fortran_frontend_trig_inverse(): + test_string = """ + PROGRAM intrinsic_math_test_hyperbolic + implicit none + real, dimension(3) :: sincos_args + real, dimension(3) :: tan_args + real, dimension(6) :: tan2_args + real, dimension(12) :: res + CALL intrinsic_math_test_function(sincos_args, tan_args, tan2_args, res) + end + + SUBROUTINE intrinsic_math_test_function(sincos_args, tan_args, tan2_args, res) + integer :: n + real, dimension(3) :: sincos_args + real, dimension(3) :: tan_args + real, dimension(6) :: tan2_args + real, dimension(12) :: res + + DO n=1,3 + res(n) = ASIN(sincos_args(n)) + END DO + + DO n=1,3 + res(n+3) = ACOS(sincos_args(n)) + END DO + + DO n=1,3 + res(n+6) = ATAN(tan_args(n)) + END DO + + DO n=1,3 + res(n+9) = ATAN2(tan2_args(2*n - 1), tan2_args(2*n)) + END DO + + END SUBROUTINE intrinsic_math_test_function + """ + + sdfg = fortran_parser.create_sdfg_from_string(test_string, "intrinsic_math_test_modulo", False) + sdfg.simplify(verbose=True) + sdfg.compile() + + size = 3 + sincos_args = np.full([size], 42, order="F", dtype=np.float32) + sincos_args[0] = -0.5 + sincos_args[1] = 0.0 + sincos_args[2] = 1.0 + + atan_args = np.full([size], 42, order="F", dtype=np.float32) + atan_args[0] = 0.0 + atan_args[1] = 1.0 + atan_args[2] = 3.14 + + atan2_args = np.full([size*2], 42, order="F", dtype=np.float32) + atan2_args[0] = 0.0 + atan2_args[1] = 1.0 + atan2_args[2] = 1.0 + atan2_args[3] = 1.0 + atan2_args[4] = 1.0 + atan2_args[5] = 0.0 + + res = np.full([size*4], 42, order="F", dtype=np.float32) + sdfg(sincos_args=sincos_args, tan_args=atan_args, tan2_args=atan2_args, res=res) + + assert np.allclose(res, [-0.523598790, 0.00000000, 1.57079637, 2.09439516, 1.57079637, 0.00000000, 0.00000000, 0.785398185, 1.26248074, 0.00000000, 0.785398185, 1.57079637]) + +if __name__ == "__main__": + + test_fortran_frontend_min_max() + test_fortran_frontend_sqrt() + test_fortran_frontend_abs() + test_fortran_frontend_exp() + test_fortran_frontend_log() + test_fortran_frontend_mod_float() + test_fortran_frontend_mod_integer() + test_fortran_frontend_modulo_float() + test_fortran_frontend_modulo_integer() + test_fortran_frontend_floor() + test_fortran_frontend_scale() + test_fortran_frontend_exponent() + test_fortran_frontend_int() + test_fortran_frontend_real() + test_fortran_frontend_trig() + test_fortran_frontend_hyperbolic() + test_fortran_frontend_trig_inverse()