# Logic-evaluation handlers of ``SimplifyMapper``

def map_comparison(self, expr, *args, **kwargs):
    """
    Simplify both sides of a comparison and, if enabled, evaluate it.

    The comparison is only replaced by a :any:`LogicLiteral` when
    :data:`Simplification.LogicEvaluation` is enabled, both (simplified)
    operands are constant and the operator is an (in)equality test.
    All other comparisons are rebuilt from the simplified operands.

    Parameters
    ----------
    expr : :any:`Comparison`
        The comparison expression to simplify

    Returns
    -------
    :any:`LogicLiteral` or :any:`Comparison`
        The evaluated literal, or the comparison with simplified operands
    """
    left = self.rec(expr.left, *args, **kwargs)
    right = self.rec(expr.right, *args, **kwargs)

    if self.enabled_simplifications & Simplification.LogicEvaluation:
        if is_constant(left) and is_constant(right):
            if expr.operator == '==':
                return sym.LogicLiteral('True' if left == right else 'False')
            # NOTE(review): assumes "not equal" is stored as '!=' -- confirm
            # against the frontends. If it is spelled differently, the
            # expression simply stays symbolic, which is still correct.
            if expr.operator == '!=':
                return sym.LogicLiteral('False' if left == right else 'True')
            # Ordering operators (<, <=, >, >=) cannot be decided from
            # operand equality alone, so they are deliberately left
            # unevaluated instead of being folded to ``False``.

    return sym.Comparison(operator=expr.operator, left=left, right=right)


def map_logical_and(self, expr, *args, **kwargs):
    """
    Simplify the operands of a logical conjunction and, if enabled, evaluate it.

    The conjunction is replaced by a :any:`LogicLiteral` only when
    :data:`Simplification.LogicEvaluation` is enabled and every
    simplified operand is itself a :any:`LogicLiteral`.

    Parameters
    ----------
    expr : :any:`LogicalAnd`
        The conjunction to simplify

    Returns
    -------
    :any:`LogicLiteral` or :any:`LogicalAnd`
        The evaluated literal, or the conjunction of the simplified operands
    """
    children = tuple(self.rec(child, *args, **kwargs) for child in expr.children)

    # Honour the enabled flags, like all other simplifications in this mapper
    if self.enabled_simplifications & Simplification.LogicEvaluation:
        if all(isinstance(c, sym.LogicLiteral) for c in children):
            if all(c == 'True' for c in children):
                return sym.LogicLiteral('True')
            return sym.LogicLiteral('False')

    return sym.LogicalAnd(children)


def map_logical_or(self, expr, *args, **kwargs):
    """
    Simplify the operands of a logical disjunction and, if enabled, evaluate it.

    The disjunction is replaced by a :any:`LogicLiteral` only when
    :data:`Simplification.LogicEvaluation` is enabled and every
    simplified operand is itself a :any:`LogicLiteral`.

    Parameters
    ----------
    expr : :any:`LogicalOr`
        The disjunction to simplify

    Returns
    -------
    :any:`LogicLiteral` or :any:`LogicalOr`
        The evaluated literal, or the disjunction of the simplified operands
    """
    children = tuple(self.rec(child, *args, **kwargs) for child in expr.children)

    # Honour the enabled flags, like all other simplifications in this mapper
    if self.enabled_simplifications & Simplification.LogicEvaluation:
        if all(isinstance(c, sym.LogicLiteral) for c in children):
            if any(c == 'True' for c in children):
                return sym.LogicLiteral('True')
            return sym.LogicLiteral('False')

    return sym.LogicalOr(children)
class DeadCodeEliminationTransformer(Transformer):
    """
    :any:`Transformer` class that removes provably unreachable code paths.

    The primary modification performed is to prune individual code branches
    under :any:`Conditional` nodes.

    Parameters
    ----------
    use_simplify : boolean
        Use :any:`simplify` when evaluating expressions for branch pruning.
    """

    def __init__(self, use_simplify=True, **kwargs):
        super().__init__(**kwargs)
        self.use_simplify = use_simplify

    def visit_Conditional(self, o, **kwargs):
        """
        Prune a :any:`Conditional` whose condition is fully determinate.

        Both branches are visited first, so nested conditionals are pruned
        before this one. If the (optionally simplified) condition compares
        equal to ``'True'`` the visited body is returned in place of the
        node; if it compares equal to ``'False'`` the visited else-body is
        returned. Otherwise the node is rebuilt from the visited children.
        """
        condition = self.visit(o.condition, **kwargs)
        body = as_tuple(flatten(as_tuple(self.visit(o.body, **kwargs))))
        else_body = as_tuple(flatten(as_tuple(self.visit(o.else_body, **kwargs))))

        if self.use_simplify:
            condition = simplify(condition)

        if condition == 'True':
            return body

        if condition == 'False':
            return else_body

        # The else-branch only remains an ``elseif`` if its leading node is
        # still a conditional after pruning. Coerce to a real boolean: the
        # plain ``and`` chain would otherwise yield the empty tuple when the
        # former ``elseif`` branch has been eliminated entirely.
        has_elseif = bool(
            o.has_elseif and else_body and isinstance(else_body[0], Conditional)
        )
        return self._rebuild(o, (condition, body, else_body), has_elseif=has_elseif)
'inline_member_procedures', + 'inline_marked_subroutines' ] @@ -190,24 +196,18 @@ def inline_elemental_functions(routine): routine.spec = Transformer(import_map).visit(routine.spec) -def inline_member_routine(routine, member): +def map_call_to_procedure_body(call, caller): """ - Inline an individual member :any:`Subroutine` at source level. - - This will replace all :any:`Call` objects to the specified - subroutine with an adjusted equivalent of the member routines' - body. For this, argument matching, including partial dimension - matching for array references is performed, and all - member-specific declarations are hoisted to the containing - :any:`Subroutine`. + Resolve arguments of a call and map to the called procedure body. Parameters ---------- - routine : :any:`Subroutine` - The subroutine in which to inline all calls to the member routine - member : :any:`Subroutine` - The contained member subroutine to be inlined in the parent + call : :any:`CallStatment` or :any:`InlineCall` + Call object that defines the argument mapping + caller : :any:`Subroutine` + Procedure (scope) into which the callee's body gets mapped """ + # pylint: disable=import-outside-toplevel,cyclic-import from loki.transform import recursive_expression_map_update @@ -229,99 +229,206 @@ def _map_unbound_dims(var, val): return val.clone(dimensions=tuple(new_dimensions)) - # Prevent shadowing of member variables by renaming them a priori + # Get callee from the procedure type + callee = call.routine + if callee is BasicType.DEFERRED: + error( + '[Loki::TransformInline] Need procedure definition to resolve ' + f'call to {call.name} from {caller}' + ) + raise RuntimeError('Procedure definition not found! 
') + + argmap = {} + callee_vars = FindVariables().visit(callee.body) + + # Match dimension indexes between the argument and the given value + # for all occurences of the argument in the body + for arg, val in call.arg_map.items(): + if isinstance(arg, sym.Array): + # Resolve implicit dimension ranges of the passed value, + # eg. when passing a two-dimensional array `a` as `call(arg=a)` + # Check if val is a DeferredTypeSymbol, as it does not have a `dimensions` attribute + if not isinstance(val, sym.DeferredTypeSymbol) and val.dimensions: + qualified_value = val + else: + qualified_value = val.clone( + dimensions=tuple(sym.Range((None, None)) for _ in arg.shape) + ) + + # If sequence association (scalar-to-array argument passing) is used, + # we cannot determine the right re-mapped iteration space, so we bail here! + if not any(isinstance(d, sym.Range) for d in qualified_value.dimensions): + error( + '[Loki::TransformInline] Cannot find free dimension resolving ' + f' array argument for value "{qualified_value}"' + ) + raise RuntimeError( + f'[Loki::TransformInline] Cannot resolve procedure call to {call.name}' + ) + arg_vars = tuple(v for v in callee_vars if v.name == arg.name) + argmap.update((v, _map_unbound_dims(v, qualified_value)) for v in arg_vars) + else: + argmap[arg] = val + + # Deal with PRESENT check for optional arguments + present_checks = tuple( + check for check in FindInlineCalls().visit(callee.body) if check.function == 'PRESENT' + ) + present_map = { + check: sym.Literal('.true.') if check.arguments[0] in call.arg_map else sym.Literal('.false.') + for check in present_checks + } + argmap.update(present_map) + + # Recursive update of the map in case of nested variables to map + argmap = recursive_expression_map_update(argmap, max_iterations=10) + + # Substitute argument calls into a copy of the body + callee_body = SubstituteExpressions(argmap, rebuild_scopes=True).visit( + callee.body.body, scope=caller + ) + + # Inline substituted body within a 
def inline_subroutine_calls(routine, calls, callee, allowed_aliases=None):
    """
    Inline a set of calls to one individual :any:`Subroutine` at source level.

    Every call in :data:`calls` is replaced in the body of :data:`routine`
    by the argument-mapped body of :data:`callee`. Before that, local
    variables of the callee that clash with names in the caller are
    prefixed with the callee's name, and the callee's local declarations
    are re-scoped and appended to the caller's specification.

    Note that the specification and body of :data:`callee` are re-assigned
    in the process.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine in which the given calls are to be inlined
    calls : tuple or list of :any:`CallStatement`
        The calls to replace; all of them must refer to :data:`callee`
    callee : :any:`Subroutine`
        The called target subroutine to be inlined in the parent
    allowed_aliases : tuple or list of str or :any:`Expression`, optional
        List of variables that will not be renamed in the parent scope, even
        if they alias with a local declaration.
    """
    allowed_aliases = as_tuple(allowed_aliases)

    # All calls handed in must target the very same callee
    assert all(call.routine == callee for call in calls)

    def _prefixed(var):
        # Clone of the variable carrying the callee name as a prefix
        return var.clone(name=f'{callee.name}_{var.name}')

    # Callee locals that clash with a caller name need renaming up front,
    # unless they are explicitly allowed to alias
    parent_variables = routine.variable_map
    duplicates = tuple(
        var for var in callee.variables
        if var.name in parent_variables
        and var.name.lower() not in callee._dummies
        and var.symbol not in allowed_aliases
    )

    # Rename the clashing variables in the callee's declarations ...
    shadow_mapper = SubstituteExpressions({var: _prefixed(var) for var in duplicates})
    callee.spec = shadow_mapper.visit(callee.spec)

    # ... and in every occurrence within the callee's body
    duplicate_names = {var.name.lower() for var in duplicates}
    var_map = {
        var: _prefixed(var)
        for var in FindVariables(unique=False).visit(callee.body)
        if var.name.lower() in duplicate_names
    }
    callee.body = SubstituteExpressions(var_map).visit(callee.body)

    # Give allowed aliases their own declaration, so they hoist cleanly
    if allowed_aliases:
        single_variable_declaration(callee, variables=allowed_aliases)

    # Pick the callee's purely local declarations not yet known to the caller
    hoisted = tuple(
        decl for decl in FindNodes(VariableDeclaration).visit(callee.spec)
        if all(s.name.lower() not in callee._dummies for s in decl.symbols)
        and all(s not in routine.variables for s in decl.symbols)
    )
    # Move the declared symbols into the caller's scope before hoisting
    hoisted = tuple(
        decl.clone(symbols=tuple(s.clone(scope=routine) for s in decl.symbols))
        for decl in hoisted
    )
    routine.spec.append(hoisted)

    # Map each call onto the argument-substituted body of the callee
    call_map = {}
    for call in calls:
        call_map[call] = map_call_to_procedure_body(call, caller=routine)

    # Swap the calls for the mapped bodies
    routine.body = Transformer(call_map).visit(routine.body)
def inline_marked_subroutines(routine, allowed_aliases=None):
    """
    Inline :any:`Subroutine` objects guided by pragma annotations.

    When encountering :any:`CallStatement` objects that are marked with a
    ``!$loki inline`` pragma, this utility will attempt to replace the call
    with the body of the called procedure and remap all passed arguments
    into the calling procedure's scope.

    Please note that this utility requires :any:`CallStatement` objects
    to be "enriched" with external type information. Marked calls for
    which no procedure definition is attached are left untouched.

    Parameters
    ----------
    routine : :any:`Subroutine`
        The subroutine in which to look for pragma-marked procedures to inline
    allowed_aliases : tuple or list of str or :any:`Expression`, optional
        List of variables that will not be renamed in the parent scope, even
        if they alias with a local declaration.
    """

    with pragmas_attached(routine, node_type=CallStatement):

        # Group the marked calls by callee routine
        call_sets = defaultdict(list)
        for call in FindNodes(CallStatement).visit(routine.body):
            if not is_loki_pragma(call.pragma, starts_with='inline'):
                continue

            callee = call.routine
            # Skip unattached calls. A missing definition is reported as
            # ``BasicType.DEFERRED`` (see the check in
            # ``map_call_to_procedure_body``), which a plain truthiness
            # test presumably does not filter out.
            if not callee or callee is BasicType.DEFERRED:
                continue

            call_sets[callee].append(call)

        # Trigger per-callee inlining on the collected sets
        for callee, calls in call_sets.items():
            inline_subroutine_calls(
                routine, calls, callee, allowed_aliases=allowed_aliases
            )
omni_include, xmod, data_offload, remove_openmp, assume_deviceptr, frontend, trim_vector_sections, - global_var_offload, remove_derived_args, inline_members, resolve_sequence_association, - derive_argument_array_shape + global_var_offload, remove_derived_args, inline_members, inline_marked, + resolve_sequence_association, derive_argument_array_shape, eliminate_dead_code ): """ Batch-processing mode for Fortran-to-Fortran transformations that @@ -210,7 +214,9 @@ def convert( # Apply the basic SCC transformation set scheduler.process( SCCBaseTransformation( horizontal=horizontal, directive=directive, - inline_members=inline_members, resolve_sequence_association=resolve_sequence_association + inline_members=inline_members, inline_marked=inline_marked, + resolve_sequence_association=resolve_sequence_association, + eliminate_dead_code=eliminate_dead_code )) scheduler.process( SCCDevectorTransformation( horizontal=horizontal, trim_vector_sections=trim_vector_sections diff --git a/tests/test_symbolic.py b/tests/test_symbolic.py index a98f660c8..92d489943 100644 --- a/tests/test_symbolic.py +++ b/tests/test_symbolic.py @@ -186,6 +186,25 @@ def test_simplify_collect_coefficients(source, ref): assert str(expr) == ref +@pytest.mark.skipif(not HAVE_FP, reason='Fparser not available') +@pytest.mark.parametrize('source, ref', [ + ('1 == 1', 'True'), + ('2 == 1', 'False'), + ('1 + 1 == 2', '1 + 1 == 2'), # Not true without integer arithmetic + ('.true. .and. .true.', 'True'), + ('.true. .and. .false.', 'False'), + ('.true. .or. .false.', 'True'), + ('.false. .or. .false.', 'False'), + ('2 == 1 .and. 1 == 1', 'False'), + ('2 == 1 .or. 
1 == 1', 'True'), +]) +def test_simplify_logic_evaluation(source, ref): + scope = Scope() + expr = parse_fparser_expression(source, scope) + expr = simplify(expr, enabled_simplifications=Simplification.LogicEvaluation) + assert str(expr) == ref + + @pytest.mark.skipif(not HAVE_FP, reason='Fparser not available') @pytest.mark.parametrize('source, ref', [ ('5 * (4 + 3 * (2 + 1) )', '65'), @@ -201,7 +220,9 @@ def test_simplify_collect_coefficients(source, ref): ('1*a*b + 0*a*b', 'a*b'), ('n+(((-1)*1)*n)', '0'), ('5 + a * (3 - b * (2 + c) / 7) * 5 - 4', '1 + 15*a - 10*a*b / 7 - 5*a*b*c / 7'), - ('(5 + 3) * a - 8 * a / 2 + a * ((7 - 1) / 3)', '6*a') + ('(5 + 3) * a - 8 * a / 2 + a * ((7 - 1) / 3)', '6*a'), + ('(5 + 3) == 8', 'True'), + ('42 == 666', 'False'), ]) def test_simplify(source,ref): scope = Scope() diff --git a/tests/test_transform_dead_code.py b/tests/test_transform_dead_code.py new file mode 100644 index 000000000..ec3c73740 --- /dev/null +++ b/tests/test_transform_dead_code.py @@ -0,0 +1,129 @@ +# (C) Copyright 2018- ECMWF. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import pytest + +from conftest import available_frontends +from loki import Subroutine, FindNodes, Conditional, Assignment, OMNI +from loki.transform import dead_code_elimination + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_transform_dead_code_conditional(frontend): + """ + Test correct elimination of unreachable conditional branches. 
+ """ + fcode = """ +subroutine test_dead_code_conditional(a, b, flag) + real(kind=8), intent(inout) :: a, b + logical, intent(in) :: flag + + if (flag) then + if (1 == 6) then + a = a + b + else + b = b + 2.0 + end if + + if (2 == 2) then + b = b + a + else + a = a + 3.0 + end if + + if (1 == 2) then + b = b + a + elseif (3 == 3) then + a = a + b + else + a = a + 6.0 + end if + + end if +end subroutine test_dead_code_conditional +""" + routine = Subroutine.from_source(fcode, frontend=frontend) + # Please note that nested conditionals (elseif) counts as two + assert len(FindNodes(Conditional).visit(routine.body)) == 5 + assert len(FindNodes(Assignment).visit(routine.body)) == 7 + + dead_code_elimination(routine) + + conditionals = FindNodes(Conditional).visit(routine.body) + assert len(conditionals) == 1 + assert conditionals[0].condition == 'flag' + assigns = FindNodes(Assignment).visit(routine.body) + assert len(assigns) == 3 + assert assigns[0].lhs == 'b' and assigns[0].rhs == 'b + 2.0' + assert assigns[1].lhs == 'b' and assigns[1].rhs == 'b + a' + assert assigns[2].lhs == 'a' and assigns[2].rhs == 'a + b' + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_transform_dead_code_conditional_nested(frontend): + """ + Test correct elimination of unreachable branches in nested conditionals. 
+ """ + fcode = """ +subroutine test_dead_code_conditional(a, b, flag) + real(kind=8), intent(inout) :: a, b + logical, intent(in) :: flag + + if (1 == 2) then + a = a + 5 + elseif (flag) then + b = b + 4 + else + b = a + 3 + end if + + if (a > 2.0) then + a = a + 5.0 + elseif (2 == 3) then + a = a + 3.0 + else + a = a + 1.0 + endif + + if (a > 2.0) then + a = a + 5.0 + elseif (2 == 3) then + a = a + 3.0 + elseif (a > 1.0) then + a = a + 2.0 + else + a = a + 1.0 + endif +end subroutine test_dead_code_conditional +""" + routine = Subroutine.from_source(fcode, frontend=frontend) + # Please note that nested conditionals (elseif) counts as two + assert len(FindNodes(Conditional).visit(routine.body)) == 7 + assert len(FindNodes(Assignment).visit(routine.body)) == 10 + + dead_code_elimination(routine) + + conditionals = FindNodes(Conditional).visit(routine.body) + assert len(conditionals) == 4 + assert conditionals[0].condition == 'flag' + assert not conditionals[0].has_elseif + assert conditionals[1].condition == 'a > 2.0' + assert not conditionals[1].has_elseif + assert conditionals[2].condition == 'a > 2.0' + if not frontend == OMNI: # OMNI does not get elseifs right + assert conditionals[2].has_elseif + assert conditionals[3].condition == 'a > 1.0' + assert not conditionals[3].has_elseif + assigns = FindNodes(Assignment).visit(routine.body) + assert len(assigns) == 7 + assert assigns[0].lhs == 'b' and assigns[0].rhs == 'b + 4' + assert assigns[1].lhs == 'b' and assigns[1].rhs == 'a + 3' + assert assigns[2].lhs == 'a' and assigns[2].rhs == 'a + 5.0' + assert assigns[3].lhs == 'a' and assigns[3].rhs == 'a + 1.0' + assert assigns[4].lhs == 'a' and assigns[4].rhs == 'a + 5.0' + assert assigns[5].lhs == 'a' and assigns[5].rhs == 'a + 2.0' + assert assigns[6].lhs == 'a' and assigns[6].rhs == 'a + 1.0' diff --git a/tests/test_transform_inline.py b/tests/test_transform_inline.py index 30dcd683e..ec70d1113 100644 --- a/tests/test_transform_inline.py +++ 
b/tests/test_transform_inline.py @@ -12,12 +12,14 @@ from conftest import jit_compile, jit_compile_lib, available_frontends from loki import ( Builder, Module, Subroutine, FindNodes, Import, FindVariables, - CallStatement, Loop, BasicType, DerivedType, Associate, OMNI + CallStatement, Loop, BasicType, DerivedType, Associate, OMNI, + Conditional, FindInlineCalls ) from loki.ir import Assignment from loki.transform import ( inline_elemental_functions, inline_constant_parameters, - replace_selected_kind, inline_member_procedures + replace_selected_kind, inline_member_procedures, + inline_marked_subroutines ) from loki.expression import symbols as sym @@ -517,6 +519,57 @@ def test_inline_member_routines_variable_shadowing(frontend): assert assign[2].lhs == 'y' and assign[2].rhs == 'y + sum(inner_x)' +@pytest.mark.parametrize('frontend', available_frontends()) +def test_inline_internal_routines_aliasing_declaration(frontend): + """ + Test declaration splitting when inlining internal procedures. + """ + fcode = """ +subroutine outer() + integer :: z + integer :: jlon + z = 0 + jlon = 0 + + call inner(z) + + jlon = z + 4 +contains + subroutine inner(z) + integer, intent(inout) :: z + integer :: jlon, jg ! 
These two need to get separated + jlon = 1 + jg = 2 + z = jlon + jg + end subroutine inner +end subroutine outer + """ + routine = Subroutine.from_source(fcode, frontend=frontend) + + # Check outer and inner variables + assert len(routine.variable_map) == 2 + assert 'z' in routine.variable_map + assert 'jlon' in routine.variable_map + + assert len(routine['inner'].variable_map) == 3 + assert 'z' in routine['inner'].variable_map + assert 'jlon' in routine['inner'].variable_map + assert 'jg' in routine['inner'].variable_map + + inline_member_procedures(routine, allowed_aliases=('jlon',)) + + assert len(routine.variable_map) == 3 + assert 'z' in routine.variable_map + assert 'jlon' in routine.variable_map + assert 'jg' in routine.variable_map + + assigns = FindNodes(Assignment).visit(routine.body) + assert len(assigns) == 6 + assert assigns[2].lhs == 'jlon' and assigns[2].rhs == '1' + assert assigns[3].lhs == 'jg' and assigns[3].rhs == '2' + assert assigns[4].lhs == 'z' and assigns[4].rhs == 'jlon + jg' + + @pytest.mark.parametrize('frontend', available_frontends()) def test_inline_member_routines_sequence_assoc(frontend): """ @@ -607,3 +660,148 @@ def test_inline_member_routines_with_associate(frontend): assocs = FindNodes(Associate).visit(routine.body) assert len(assocs) == 2 + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_inline_marked_subroutines(frontend): + """ Test subroutine inlining via marker pragmas. 
""" + + fcode_driver = """ +subroutine test_pragma_inline(a, b) + use util_mod, only: add_one, add_a_to_b + implicit none + + real(kind=8), intent(inout) :: a(3), b(3) + integer, parameter :: n = 3 + integer :: i + + do i=1, n + !$loki inline + call add_one(a(i)) + end do + + !$loki inline + call add_a_to_b(a(:), b(:), 3) + + do i=1, n + call add_one(b(i)) + end do + +end subroutine test_pragma_inline + """ + + fcode_module = """ +module util_mod +implicit none + +contains + subroutine add_one(a) + real(kind=8), intent(inout) :: a + a = a + 1 + end subroutine add_one + + subroutine add_a_to_b(a, b, n) + real(kind=8), intent(inout) :: a(:), b(:) + integer, intent(in) :: n + integer :: i + + do i = 1, n + a(i) = a(i) + b(i) + end do + end subroutine add_a_to_b +end module util_mod +""" + module = Module.from_source(fcode_module, frontend=frontend) + driver = Subroutine.from_source(fcode_driver, frontend=frontend) + driver.enrich(module) + + calls = FindNodes(CallStatement).visit(driver.body) + assert calls[0].routine == module['add_one'] + assert calls[1].routine == module['add_a_to_b'] + assert calls[2].routine == module['add_one'] + + inline_marked_subroutines(routine=driver, allowed_aliases=('I',)) + + # Check inlined loops and assignments + assert len(FindNodes(Loop).visit(driver.body)) == 3 + assign = FindNodes(Assignment).visit(driver.body) + assert len(assign) == 2 + assert assign[0].lhs == 'a(i)' and assign[0].rhs == 'a(i) + 1' + assert assign[1].lhs == 'a(i)' and assign[1].rhs == 'a(i) + b(i)' + + # Check that the last call is left untouched + calls = FindNodes(CallStatement).visit(driver.body) + assert len(calls) == 1 + assert calls[0].routine.name == 'add_one' + assert calls[0].arguments == ('b(i)',) + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_inline_marked_routine_with_optionals(frontend): + """ Test subroutine inlining via marker pragmas with omitted optionals. 
""" + + fcode_driver = """ +subroutine test_pragma_inline_optionals(a, b) + use util_mod, only: add_one + implicit none + + real(kind=8), intent(inout) :: a(3), b(3) + integer, parameter :: n = 3 + integer :: i + + do i=1, n + !$loki inline + call add_one(a(i), two=2.0) + end do + + do i=1, n + !$loki inline + call add_one(b(i)) + end do + +end subroutine test_pragma_inline_optionals + """ + + fcode_module = """ +module util_mod +implicit none + +contains + subroutine add_one(a, two) + real(kind=8), intent(inout) :: a + real(kind=8), optional, intent(inout) :: two + a = a + 1 + + if (present(two)) then + a = a + two + end if + end subroutine add_one +end module util_mod +""" + module = Module.from_source(fcode_module, frontend=frontend) + driver = Subroutine.from_source(fcode_driver, frontend=frontend) + driver.enrich(module) + + calls = FindNodes(CallStatement).visit(driver.body) + assert calls[0].routine == module['add_one'] + assert calls[1].routine == module['add_one'] + + inline_marked_subroutines(routine=driver) + + # Check inlined loops and assignments + assert len(FindNodes(Loop).visit(driver.body)) == 2 + assign = FindNodes(Assignment).visit(driver.body) + assert len(assign) == 4 + assert assign[0].lhs == 'a(i)' and assign[0].rhs == 'a(i) + 1' + assert assign[1].lhs == 'a(i)' and assign[1].rhs == 'a(i) + 2.0' + assert assign[2].lhs == 'b(i)' and assign[2].rhs == 'b(i) + 1' + # TODO: This is a problem, since it's not declared anymore + assert assign[3].lhs == 'b(i)' and assign[3].rhs == 'b(i) + two' + + # Check that the PRESENT checks have been resolved + assert len(FindNodes(CallStatement).visit(driver.body)) == 0 + assert len(FindInlineCalls().visit(driver.body)) == 0 + checks = FindNodes(Conditional).visit(driver.body) + assert len(checks) == 2 + assert checks[0].condition == 'True' + assert checks[1].condition == 'False' diff --git a/transformations/tests/test_single_column_coalesced.py b/transformations/tests/test_single_column_coalesced.py index 
7cccbe50a..4838388e2 100644 --- a/transformations/tests/test_single_column_coalesced.py +++ b/transformations/tests/test_single_column_coalesced.py @@ -1774,7 +1774,7 @@ def test_single_column_coalesced_inline_and_sequence_association(frontend, horiz with pytest.raises(RuntimeError) as e_info: scc_transform.apply(routine, role='kernel') assert(e_info.exconly() == - 'RuntimeError: [Loki::TransformInline] Unable to resolve member subroutine call') + 'RuntimeError: [Loki::TransformInline] Cannot resolve procedure call to contained_kernel') #Check that the call is properly modified elif (not inline_members and resolve_sequence_association): diff --git a/transformations/transformations/single_column_coalesced.py b/transformations/transformations/single_column_coalesced.py index 5bf2f9cb1..3bc92a186 100644 --- a/transformations/transformations/single_column_coalesced.py +++ b/transformations/transformations/single_column_coalesced.py @@ -7,7 +7,11 @@ import re from loki.expression import symbols as sym -from loki.transform import resolve_associates, inline_member_procedures, transform_sequence_association +from loki.transform import ( + resolve_associates, inline_member_procedures, + inline_marked_subroutines, transform_sequence_association, + dead_code_elimination +) from loki import ( Transformation, FindNodes, Transformer, info, pragmas_attached, as_tuple, flatten, ir, FindExpressions, @@ -38,16 +42,28 @@ class methods can be called directly. ``'openacc'`` or ``None``. inline_members : bool Enable full source-inlining of member subroutines; default: False. + inline_marked : bool + Enable inlining for subroutines marked with ``!$loki inline``; default: True. + resolve_sequence_association : bool + Replace scalars that are passed to array arguments with array ranges; default: False. + eliminate_dead_code : bool + Perform dead code elimination, where unreachable branches are trimmed from the code; default: True.
""" - def __init__(self, horizontal, directive=None, inline_members=False, resolve_sequence_association=False): + def __init__( + self, horizontal, directive=None, inline_members=False, + inline_marked=True, resolve_sequence_association=False, + eliminate_dead_code=True + ): self.horizontal = horizontal assert directive in [None, 'openacc'] self.directive = directive self.inline_members = inline_members + self.inline_marked = inline_marked self.resolve_sequence_association = resolve_sequence_association + self.eliminate_dead_code = eliminate_dead_code @classmethod def check_routine_pragmas(cls, routine, directive): @@ -296,14 +312,23 @@ def process_kernel(self, routine): if self.resolve_sequence_association: transform_sequence_association(routine) - # Perform full source-inlining for member subroutines if so requested + # Perform full source-inlining for member subroutines if self.inline_members: - inline_member_procedures(routine) + inline_member_procedures(routine, allowed_aliases=(self.horizontal.index,)) + + # Perform full source-inlining for pragma-marked subroutines + if self.inline_marked: + # When inlining we allow the horizontal dimension to alias, so that + # the de/re-vectorisation captures the shared vector dimension. + inline_marked_subroutines(routine, allowed_aliases=(self.horizontal.index,)) # Associates at the highest level, so they don't interfere # with the sections we need to do for detecting subroutine calls resolve_associates(routine) + if self.eliminate_dead_code: + dead_code_elimination(routine) + # Resolve WHERE clauses self.resolve_masked_stmts(routine, loop_variable=v_index)