Skip to content

Commit

Permalink
Merge pull request #198 from ecmwf-ifs/naml-routine-inline-pragma
Browse files Browse the repository at this point in the history
Pragma-driven subroutine inlining and associated utilities
  • Loading branch information
reuterbal authored Dec 20, 2023
2 parents 2f5158a + 7652b95 commit 9f25284
Show file tree
Hide file tree
Showing 10 changed files with 683 additions and 95 deletions.
35 changes: 34 additions & 1 deletion loki/expression/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,13 +496,16 @@ class Simplification(enum.Flag):
Flatten Flatten sub-sums and distribute products.
IntegerArithmetic Perform arithmetic on integer literals (addition and multiplication).
CollectCoefficients Combine summands as far as possible.
LogicEvaluation Resolve logically fully determinate expressions, like ``1 == 1`` or ``1 == 6``
ALL All of the above.
"""
Flatten = enum.auto()
IntegerArithmetic = enum.auto()
CollectCoefficients = enum.auto()
LogicEvaluation = enum.auto()

ALL = Flatten | IntegerArithmetic | CollectCoefficients # pylint: disable=unsupported-binary-operation
# pylint: disable-next=unsupported-binary-operation
ALL = Flatten | IntegerArithmetic | CollectCoefficients | LogicEvaluation


class SimplifyMapper(LokiIdentityMapper):
Expand Down Expand Up @@ -566,6 +569,36 @@ def map_quotient(self, expr, *args, **kwargs):
map_parenthesised_mul = map_product
map_parenthesised_div = map_quotient

def map_comparison(self, expr, *args, **kwargs):
left = self.rec(expr.left, *args, **kwargs)
right = self.rec(expr.right, *args, **kwargs)

if self.enabled_simplifications & Simplification.LogicEvaluation:
if is_constant(left) and is_constant(right):
if expr.operator == '==' and left == right:
return sym.LogicLiteral('True')
return sym.LogicLiteral('False')

return sym.Comparison(operator=expr.operator, left=left, right=right)

def map_logical_and(self, expr, *args, **kwargs):
children = tuple(self.rec(child, *args, **kwargs) for child in expr.children)
if all(isinstance(c, sym.LogicLiteral) for c in children):
if all(c == 'True' for c in children):
return sym.LogicLiteral('True')
return sym.LogicLiteral('False')

return sym.LogicalAnd(children)

def map_logical_or(self, expr, *args, **kwargs):
children = tuple(self.rec(child, *args, **kwargs) for child in expr.children)
if all(isinstance(c, sym.LogicLiteral) for c in children):
if any(c == 'True' for c in children):
return sym.LogicLiteral('True')
return sym.LogicLiteral('False')

return sym.LogicalOr(children)


def simplify(expr, enabled_simplifications=Simplification.ALL):
"""
Expand Down
1 change: 1 addition & 0 deletions loki/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@
from loki.transform.transform_parametrise import * # noqa
from loki.transform.transform_extract_contained_procedures import * # noqa
from loki.transform.transform_sequence_association import * # noqa
from loki.transform.transform_dead_code import * # noqa
68 changes: 68 additions & 0 deletions loki/transform/transform_dead_code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Collection of utilities to perform Dead Code Elimination.
"""
from loki.visitors import Transformer
from loki.expression.symbolic import simplify
from loki.tools import flatten, as_tuple
from loki.ir import Conditional


__all__ = ['dead_code_elimination', 'DeadCodeEliminationTransformer']


def dead_code_elimination(routine, use_simplify=True):
"""
Perform Dead Code Elimination on the given :any:`Subroutine` object.
Parameters
----------
routine : :any:`Subroutine`
The subroutine to which to apply dead code elimination.
simplify : boolean
Use :any:`simplify` when evaluating expressions for branch pruning.
"""

transformer = DeadCodeEliminationTransformer(use_simplify=use_simplify)
routine.body = transformer.visit(routine.body)


class DeadCodeEliminationTransformer(Transformer):
"""
:any:`Transformer` class that removes provably unreachable code paths.
The pirmary modification performed is to prune individual code branches
under :any:`Conditional` nodes.
Parameters
----------
simplify : boolean
Use :any:`simplify` when evaluating expressions for branch pruning.
"""

def __init__(self, use_simplify=True, **kwargs):
super().__init__(**kwargs)
self.use_simplify = use_simplify

def visit_Conditional(self, o, **kwargs):
condition = self.visit(o.condition, **kwargs)
body = as_tuple(flatten(as_tuple(self.visit(o.body, **kwargs))))
else_body = as_tuple(flatten(as_tuple(self.visit(o.else_body, **kwargs))))

if self.use_simplify:
condition = simplify(condition)

if condition == 'True':
return body

if condition == 'False':
return else_body

has_elseif = o.has_elseif and else_body and isinstance(else_body[0], Conditional)
return self._rebuild(o, tuple((condition,) + (body,) + (else_body,)), has_elseif=has_elseif)
Loading

0 comments on commit 9f25284

Please sign in to comment.