Skip to content

Commit

Permalink
Merge pull request #205 from ecmwf-ifs/naml-pragma-inline-remove-imports
Browse files Browse the repository at this point in the history
Recursive inlining via InlineTransform and associated fixes
  • Loading branch information
reuterbal authored Jan 26, 2024
2 parents 44fcf08 + 63a4434 commit 903798f
Show file tree
Hide file tree
Showing 16 changed files with 632 additions and 321 deletions.
3 changes: 1 addition & 2 deletions loki/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from loki.transform.transformation import * # noqa
from loki.transform.transform_utilities import * # noqa
from loki.transform.transform_array_indexing import * # noqa
from loki.transform.transform_associates import * # noqa
from loki.transform.transform_inline import * # noqa
from loki.transform.transform_loop import * # noqa
from loki.transform.transform_region import * # noqa
Expand All @@ -20,5 +19,5 @@
from loki.transform.transform_hoist_variables import * # noqa
from loki.transform.transform_parametrise import * # noqa
from loki.transform.transform_extract_contained_procedures import * # noqa
from loki.transform.transform_sequence_association import * # noqa
from loki.transform.transform_dead_code import * # noqa
from loki.transform.transform_sanitise import * # noqa
2 changes: 1 addition & 1 deletion loki/transform/fortran_c_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
resolve_vector_notation, normalize_array_shape_and_access,
flatten_arrays
)
from loki.transform.transform_associates import resolve_associates
from loki.transform.transform_sanitise import resolve_associates
from loki.transform.transform_utilities import (
convert_to_lower_case, replace_intrinsics, sanitise_imports
)
Expand Down
2 changes: 1 addition & 1 deletion loki/transform/fortran_python_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from loki.transform.transform_array_indexing import (
shift_to_zero_indexing, invert_array_indices, normalize_range_indexing
)
from loki.transform.transform_associates import resolve_associates
from loki.transform.transform_sanitise import resolve_associates
from loki.transform.transform_utilities import (
convert_to_lower_case, replace_intrinsics
)
Expand Down
69 changes: 0 additions & 69 deletions loki/transform/transform_associates.py

This file was deleted.

143 changes: 132 additions & 11 deletions loki/transform/transform_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
from loki.tools import as_tuple
from loki.logging import warning, error
from loki.pragma_utils import pragmas_attached, is_loki_pragma
from loki.subroutine import Subroutine

from loki.transform.transformation import Transformation
from loki.transform.transform_dead_code import dead_code_elimination
from loki.transform.transform_utilities import (
single_variable_declaration,
recursive_expression_map_update
Expand All @@ -32,10 +35,89 @@
__all__ = [
'inline_constant_parameters', 'inline_elemental_functions',
'inline_internal_procedures', 'inline_member_procedures',
'inline_marked_subroutines'
'inline_marked_subroutines', 'InlineTransformation'
]


class InlineTransformation(Transformation):
"""
:any:`Transformation` class to apply several types of source inlining
when batch-processing large source trees via the :any:`Scheduler`.
Parameters
----------
inline_constants : bool
Replace instances of variables with known constant values by
:any:`Literal` (see :any:`inline_constant_parameters`); default: False.
inline_elementals : bool
Replaces :any:`InlineCall` expression to elemental functions
with the called function's body (see :any:`inline_elemental_functions`);
default: True.
inline_internals : bool
Inline internal procedure (see :any:`inline_internal_procedures`);
default: False.
inline_marked : bool
Inline :any:`Subroutine` objects marked by pragma annotations
(see :any:`inline_marked_subroutines`); default: True.
eliminate_dead_code : bool
Perform dead code elimination, where unreachable branches are
trimmed from the code (see :any:`dead_code_elimination`); default: True
allowed_aliases : tuple or list of str or :any:`Expression`, optional
List of variables that will not be renamed in the parent scope during
internal and pragma-driven inlining.
remove_imports : bool
Strip unused import symbols after pragma-inlining (optional, default: True)
external_only : bool, optional
Do not replace variables declared in the local scope when
inlining constants (default: True)
"""

# Ensure correct recursive inlining by traversing from the leaves
reverse_traversal = True

def __init__(
self, inline_constants=False, inline_elementals=True,
inline_internals=False, inline_marked=True,
eliminate_dead_code=True, allowed_aliases=None,
remove_imports=True, external_only=True
):
self.inline_constants = inline_constants
self.inline_elementals = inline_elementals
self.inline_internals = inline_internals
self.inline_marked = inline_marked

self.eliminate_dead_code = eliminate_dead_code

self.allowed_aliases = allowed_aliases
self.remove_imports = remove_imports
self.external_only = external_only

def transform_subroutine(self, routine, **kwargs):

# Replace constant parameter variables with explicit values
if self.inline_constants:
inline_constant_parameters(routine, external_only=self.external_only)

# Inline elemental functions
if self.inline_elementals:
inline_elemental_functions(routine)

# Inline internal (contained) procedures
if self.inline_internals:
inline_internal_procedures(routine, allowed_aliases=self.allowed_aliases)

# Inline explicitly pragma-marked subroutines
if self.inline_marked:
inline_marked_subroutines(
routine, allowed_aliases=self.allowed_aliases,
remove_imports=self.remove_imports
)

# After inlining, attempt to trim unreachable code paths
if self.eliminate_dead_code:
dead_code_elimination(routine)


class InlineSubstitutionMapper(LokiIdentityMapper):
"""
An expression mapper that defines symbolic substitution for inlining.
Expand Down Expand Up @@ -101,15 +183,23 @@ def inline_constant_parameters(routine, external_only=True):
"""
Replace instances of variables with known constant values by `Literals`.
:param external_only: Do not replace variables declared in the local scope
Notes
-----
The ``.type.initial`` property is used to derive the replacement
value,a which means for symbols imported from external modules,
the parent :any:`Module` needs to be supplied in the
``definitions`` to the constructor when creating :param:`routine`.
Note, the `.type.initial` property is used to derive the replacement value,
which means for symbols imported from external modules, the parent `Module`
needs to be supplied in the `definitions` to the constructor when creating
:param routine:.
Variables that are replaced are also removed from their
corresponding import statements, with empty import statements
being removed alltogether.
Variables that are replaced are also removed from their corresponding import
statements, with empty import statements being removed alltogether.
Parameters
----------
routine : :any:`Subroutine`
Procedure in which to inline/resolve constant parameters.
external_only : bool, optional
Do not replace variables declared in the local scope (default: True)
"""
# Find all variable instances in spec and body
variables = FindVariables().visit(routine.ir)
Expand Down Expand Up @@ -179,7 +269,10 @@ def inline_elemental_functions(routine):

exprmap = {}
for call in FindInlineCalls().visit(routine.body):
if call.procedure_type is not BasicType.DEFERRED:
if call.procedure_type is BasicType.DEFERRED:
continue

if call.procedure_type.is_function and call.procedure_type.is_elemental:
# Map each call to its substitutions, as defined by the
# recursive inline substitution mapper
exprmap[call] = InlineSubstitutionMapper()(call, scope=routine)
Expand All @@ -193,7 +286,7 @@ def inline_elemental_functions(routine):
# Remove all module imports that have become obsolete now
import_map = {}
for im in FindNodes(Import).visit(routine.spec):
if all(hasattr(s, 'type') and s.type.dtype in removed_functions for s in im.symbols):
if im.symbols and all(s.type.dtype in removed_functions for s in im.symbols):
import_map[im] = None
routine.spec = Transformer(import_map).visit(routine.spec)

Expand Down Expand Up @@ -319,6 +412,7 @@ def inline_subroutine_calls(routine, calls, callee, allowed_aliases=None):

# Ensure we process sets of calls to the same callee
assert all(call.routine == callee for call in calls)
assert isinstance(callee, Subroutine)

# Prevent shadowing of callee's variables by renaming them a priori
parent_variables = routine.variable_map
Expand Down Expand Up @@ -397,7 +491,7 @@ def inline_internal_procedures(routine, allowed_aliases=None):
inline_member_procedures = inline_internal_procedures


def inline_marked_subroutines(routine, allowed_aliases=None):
def inline_marked_subroutines(routine, allowed_aliases=None, remove_imports=True):
"""
Inline :any:`Subroutine` objects guided by pragma annotations.
Expand All @@ -416,19 +510,46 @@ def inline_marked_subroutines(routine, allowed_aliases=None):
allowed_aliases : tuple or list of str or :any:`Expression`, optional
List of variables that will not be renamed in the parent scope, even
if they alias with a local declaration.
remove_imports : bool
Strip unused import symbols after inlining (optional, default: True)
"""

with pragmas_attached(routine, node_type=CallStatement):

# Group the marked calls by callee routine
call_sets = defaultdict(list)
no_call_sets = defaultdict(list)
for call in FindNodes(CallStatement).visit(routine.body):
if call.routine == BasicType.DEFERRED:
continue

if is_loki_pragma(call.pragma, starts_with='inline'):
call_sets[call.routine].append(call)
else:
no_call_sets[call.routine].append(call)

# Trigger per-call inlining on collected sets
for callee, calls in call_sets.items():
if callee: # Skip the unattached calls (collected under None)
inline_subroutine_calls(
routine, calls, callee, allowed_aliases=allowed_aliases
)

# Remove imported symbols that have become obsolete
if remove_imports:
callees = tuple(callee.procedure_symbol for callee in call_sets.keys())
not_inlined = tuple(callee.procedure_symbol for callee in no_call_sets.keys())

import_map = {}
for impt in FindNodes(Import).visit(routine.spec):
# Remove interface header imports
if any(f'{c.name.lower()}.intfb.h' == impt.module for c in callees):
import_map[impt] = None

if any(s.name in callees for s in impt.symbols):
new_symbols = tuple(
s for s in impt.symbols if s.name not in callees or s.name in not_inlined
)
# Remove import if no further symbols used, otherwise clone with new symbols
import_map[impt] = impt.clone(symbols=new_symbols) if new_symbols else None
routine.spec = Transformer(import_map).visit(routine.spec)
Loading

0 comments on commit 903798f

Please sign in to comment.