Skip to content

Commit

Permalink
Merge pull request #212 from ecmwf-ifs/nabr-minor-frontend-fixes
Browse files Browse the repository at this point in the history
Minor fixes to frontends and IR nodes
  • Loading branch information
reuterbal authored Jan 26, 2024
2 parents 903798f + 15f0de3 commit 13739a1
Show file tree
Hide file tree
Showing 19 changed files with 454 additions and 64 deletions.
8 changes: 4 additions & 4 deletions loki/frontend/fparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -916,10 +916,10 @@ def visit_Procedure_Declaration_Stmt(self, o, **kwargs):
interface = return_type.dtype
_type = SymbolAttributes(BasicType.DEFERRED, **attrs)

# Make sure any "initial" symbol (i.e. the procedure we're binding to) is in the right scope
if _type.initial is not None:
initial = AttachScopesMapper()(_type.initial, scope=scope)
_type = _type.clone(initial=initial)
# Make sure any "bind_names" symbol (i.e. the procedure we're binding to) is in the right scope
if _type.bind_names is not None:
bind_names = AttachScopesMapper()(_type.bind_names, scope=scope)
_type = _type.clone(bind_names=bind_names)

# Update symbol table entries
if return_type is None:
Expand Down
18 changes: 14 additions & 4 deletions loki/frontend/ofp.py
Original file line number Diff line number Diff line change
Expand Up @@ -1083,12 +1083,20 @@ def visit_interface(self, o, **kwargs):
spec = spec.rescope(scope=scope)

body = []
# Sometimes, OFP gets it wrong and puts the declarations inside <declaration>,
# on other occasions it puts it inside <specification>... ¯\_(ツ)_/¯
grouped_elems = match_tag_sequence(o.find('body/specification/declaration'), [
('names', 'procedure-stmt'),
('function', ),
('subroutine', ),
('comment', ),
])
if not grouped_elems:
grouped_elems = match_tag_sequence(o.find('body/specification'), [
('function',),
('subroutine',),
('comment',)
])

for group in grouped_elems:
if len(group) == 1:
Expand Down Expand Up @@ -1365,7 +1373,7 @@ def visit_deallocate(self, o, **kwargs):
source=kwargs['source'], status_var=kw_args.get('stat'))

def visit_use(self, o, **kwargs):
name, module = self.visit(o.find('use-stmt'), **kwargs)
name, module, nature = self.visit(o.find('use-stmt'), **kwargs)
scope = kwargs['scope']
if o.find('only') is not None:
# ONLY list given (import only selected symbols)
Expand Down Expand Up @@ -1422,7 +1430,7 @@ def visit_use(self, o, **kwargs):
})
rename_list = tuple(rename_list.items()) if rename_list else None
return ir.Import(module=name, symbols=symbols, rename_list=rename_list,
label=kwargs['label'], source=kwargs['source'])
nature=nature, label=kwargs['label'], source=kwargs['source'])

def visit_only(self, o, **kwargs):
count = int(o.find('only-list').get('count'))
Expand All @@ -1438,8 +1446,10 @@ def visit_rename(self, o, **kwargs):
def visit_use_stmt(self, o, **kwargs):
name = o.attrib['id']
if o.attrib['hasModuleNature'] != 'false':
self.warn_or_fail('module nature in USE statement not implemented')
return name, self.definitions.get(name)
self.warn_or_fail('Module nature in USE statement not implemented. Assuming INTRINSIC')
# Do not return module reference for intrinsic modules
return name, None, 'intrinsic'
return name, self.definitions.get(name), None

def visit_directive(self, o, **kwargs):
source = kwargs['source']
Expand Down
51 changes: 30 additions & 21 deletions loki/frontend/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,14 @@ class RegexParserClass(Flag):
pattern matching can be switched on and off for some pattern classes, and thus the overall
parse time reduced.
"""
EmptyClass = 0
ProgramUnitClass = auto()
InterfaceClass = auto()
ImportClass = auto()
TypeDefClass = auto()
DeclarationClass = auto()
CallClass = auto()
AllClasses = ProgramUnitClass | ImportClass | TypeDefClass | DeclarationClass | CallClass # pylint: disable=unsupported-binary-operation
AllClasses = ProgramUnitClass | InterfaceClass | ImportClass | TypeDefClass | DeclarationClass | CallClass # pylint: disable=unsupported-binary-operation


class Pattern:
Expand Down Expand Up @@ -182,22 +184,21 @@ def match_statement_candidates(cls, reader, candidates, parser_classes=None, sco
filtered_candidates = [
candidate for candidate in filtered_candidates if candidate.parser_class & parser_classes
]
if not filtered_candidates:
return []

ir_ = []
last_match = -1
for idx, _ in enumerate(reader):
for candidate in filtered_candidates:
match = candidate.match(reader, parser_classes=parser_classes, scope=scope)
if match:
if last_match - idx > 1:
span = (reader.sanitized_spans[last_match + 1], reader.sanitized_spans[idx])
source = reader.source_from_sanitized_span(span)
ir_ += [ir.RawSource(source.string, source=source)]
last_match = idx
ir_ += [match]
break
if filtered_candidates:
for idx, _ in enumerate(reader):
for candidate in filtered_candidates:
match = candidate.match(reader, parser_classes=parser_classes, scope=scope)
if match:
if last_match - idx > 1:
span = (reader.sanitized_spans[last_match + 1], reader.sanitized_spans[idx])
source = reader.source_from_sanitized_span(span)
ir_ += [ir.RawSource(source.string, source=source)]
last_match = idx
ir_ += [match]
break

if head is not None and ir_:
ir_ = [ir.RawSource(text=head.string, source=head)] + ir_
Expand Down Expand Up @@ -446,7 +447,8 @@ def match(self, reader, parser_classes, scope):
contains = None

module.__initialize__( # pylint: disable=unnecessary-dunder-call
name=module.name, spec=spec, contains=contains, source=module.source, incomplete=True
name=module.name, spec=spec, contains=contains, source=module.source, incomplete=True,
parser_classes=parser_classes
)

if match.span()[0] > 0:
Expand Down Expand Up @@ -537,7 +539,8 @@ def match(self, reader, parser_classes, scope):

routine.__initialize__( # pylint: disable=unnecessary-dunder-call
name=routine.name, args=routine._dummies, is_function=routine.is_function,
prefix=prefix, spec=spec, contains=contains, source=routine.source, incomplete=True
prefix=prefix, spec=spec, contains=contains, source=routine.source,
incomplete=True, parser_classes=parser_classes
)

if match.span()[0] > 0:
Expand All @@ -552,7 +555,7 @@ class InterfacePattern(Pattern):
Pattern to match :any:`Interface` objects
"""

parser_class = RegexParserClass.ProgramUnitClass
parser_class = RegexParserClass.InterfaceClass

def __init__(self):
super().__init__(
Expand Down Expand Up @@ -611,7 +614,7 @@ class ProcedureStatementPattern(Pattern):
Pattern to match procedure statements in interfaces
"""

parser_class = RegexParserClass.ProgramUnitClass
parser_class = RegexParserClass.InterfaceClass

def __init__(self):
super().__init__(
Expand Down Expand Up @@ -772,8 +775,8 @@ def match(self, reader, parser_classes, scope):
symbols += [sym.Variable(name=s[0], type=type_, scope=scope)]
else:
type_ = SymbolAttributes(ProcedureType(name=s[1]))
initial = sym.Variable(name=s[1], type=type_, scope=scope.parent)
symbols += [sym.Variable(name=s[0], type=type_.clone(initial=initial), scope=scope)]
bind_name = sym.Variable(name=s[1], type=type_, scope=scope.parent)
symbols += [sym.Variable(name=s[0], type=type_.clone(bind_names=(bind_name,)), scope=scope)]

return ir.ProcedureDeclaration(symbols=symbols, source=reader.source_from_current_line())

Expand Down Expand Up @@ -896,7 +899,8 @@ class VariableDeclarationPattern(Pattern):
def __init__(self):
super().__init__(
r'^(((?:type|class)[ \t]*\([ \t]*(?P<typename>\w+)[ \t]*\))|' # TYPE or CLASS keyword with typename
r'^([ \t]*(?P<basic_type>(logical|real|integer|complex|character))(\((kind|len)=[a-z0-9_-]+\))?[ \t]*))'
r'^([ \t]*(?P<basic_type>(logical|real|integer|complex|character))'
r'(?P<param>\((kind|len)=[a-z0-9_-]+\))?[ \t]*))'
r'(?:[ \t]*,[ \t]*[a-z]+(?:\((.(\(.*\))?)*?\))?)*' # Optional attributes
r'(?:[ \t]*::)?' # Optional `::` delimiter
r'[ \t]*' # Some white space
Expand Down Expand Up @@ -928,6 +932,11 @@ def match(self, reader, parser_classes, scope):
type_ = SymbolAttributes(BasicType.from_str(match['basic_type']))
assert type_

if match['param']:
param = match['param'].strip().strip('()').split('=')
if len(param) == 1 or param[0].lower() == 'kind':
type_ = type_.clone(kind=sym.Variable(name=param[-1], scope=scope))

variables = self._remove_quoted_string_nested_parentheses(match['variables']) # Remove dimensions
variables = re.sub(r'=(?:>)?[^,]*(?=,|$)', r'', variables) # Remove initialization
variables = variables.replace(' ', '').split(',') # Variable names without white space
Expand Down
28 changes: 28 additions & 0 deletions loki/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,18 @@ def symbols(self):
return (self.spec,) + symbols
return symbols

@property
def symbol_map(self):
"""
Map symbol name to symbol declared by this interface
"""
return CaseInsensitiveDict(
(s.name.lower(), s) for s in self.symbols
)

def __contains__(self, name):
return name in self.symbol_map

def __repr__(self):
symbols = ', '.join(str(var) for var in self.symbols)
if self.abstract:
Expand Down Expand Up @@ -1515,6 +1527,22 @@ def imported_symbol_map(self):
"""
return CaseInsensitiveDict((s.name, s) for s in self.imported_symbols)

def __contains__(self, name):
"""
Check if a symbol with the given name is declared in this type
"""
return name in self.variables

@property
def interface_symbols(self):
"""
Return the list of symbols declared via interfaces in this unit
This returns always an empty tuple since there are no interface declarations
allowed in typedefs.
"""
return ()

@property
def dtype(self):
"""
Expand Down
4 changes: 3 additions & 1 deletion loki/lint/linter.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from codetiming import Timer

from loki.build import workqueue
from loki.bulk import Scheduler, SchedulerConfig
from loki.bulk import Scheduler, SchedulerConfig, Item
from loki.config import config as loki_config
from loki.lint.reporter import (
FileReport, RuleReport, Reporter, LazyTextfile,
Expand Down Expand Up @@ -251,6 +251,8 @@ class LinterTransformation(Transformation):
# This transformation is applied over the file graph
traverse_file_graph = True

item_filter = Item # Include everything in the dependency tree

def __init__(self, linter, key=None, **kwargs):
self.linter = linter
self.counter = 0
Expand Down
12 changes: 7 additions & 5 deletions loki/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,15 @@ class Module(ProgramUnit):
Mark the object as incomplete, i.e. only partially parsed. This is
typically the case when it was instantiated using the :any:`Frontend.REGEX`
frontend and a full parse using one of the other frontends is pending.
parser_classes : :any:`RegexParserClass`, optional
Provide the list of parser classes used during incomplete regex parsing
"""

def __init__(
self, name=None, docstring=None, spec=None, contains=None,
default_access_spec=None, public_access_spec=None, private_access_spec=None,
ast=None, source=None, parent=None, symbol_attrs=None, rescope_symbols=False,
incomplete=False
incomplete=False, parser_classes=None
):
super().__init__(parent=parent)

Expand All @@ -84,12 +86,12 @@ def __init__(
name=name, docstring=docstring, spec=spec, contains=contains,
default_access_spec=default_access_spec, public_access_spec=public_access_spec,
private_access_spec=private_access_spec, ast=ast, source=source,
rescope_symbols=rescope_symbols, incomplete=incomplete
rescope_symbols=rescope_symbols, incomplete=incomplete, parser_classes=parser_classes
)

def __initialize__(
self, name=None, docstring=None, spec=None, contains=None,
ast=None, source=None, rescope_symbols=False, incomplete=False,
ast=None, source=None, rescope_symbols=False, incomplete=False, parser_classes=None,
default_access_spec=None, public_access_spec=None, private_access_spec=None
):
# Apply dimension pragma annotations to declarations
Expand All @@ -110,7 +112,7 @@ def __initialize__(

super().__initialize__(
name=name, docstring=docstring, spec=spec, contains=contains, ast=ast,
source=source, rescope_symbols=rescope_symbols, incomplete=incomplete
source=source, rescope_symbols=rescope_symbols, incomplete=incomplete, parser_classes=parser_classes
)

@classmethod
Expand Down Expand Up @@ -299,4 +301,4 @@ def definitions(self):
Returns :any:`Subroutine` and :any:`TypeDef` nodes declared
in this module
"""
return self.subroutines + self.typedefs + self.variables
return self.subroutines + self.typedefs + self.variables + self.interfaces
42 changes: 39 additions & 3 deletions loki/program_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@

from loki import ir
from loki.expression import Variable
from loki.frontend import Frontend, parse_omni_source, parse_ofp_source, parse_fparser_source
from loki.frontend import (
Frontend, parse_omni_source, parse_ofp_source, parse_fparser_source,
RegexParserClass
)
from loki.logging import debug
from loki.scope import Scope
from loki.tools import CaseInsensitiveDict, as_tuple, flatten
Expand Down Expand Up @@ -54,16 +57,20 @@ class ProgramUnit(Scope):
Mark the object as incomplete, i.e. only partially parsed. This is
typically the case when it was instantiated using the :any:`Frontend.REGEX`
frontend and a full parse using one of the other frontends is pending.
parser_classes : :any:`RegexParserClass`, optional
Provide the list of parser classes used during incomplete regex parsing
"""

def __initialize__(self, name, docstring=None, spec=None, contains=None,
ast=None, source=None, rescope_symbols=False, incomplete=False):
ast=None, source=None, rescope_symbols=False, incomplete=False,
parser_classes=None):
# Common properties
assert name and isinstance(name, str)
self.name = name
self._ast = ast
self._source = source
self._incomplete = incomplete
self._parser_classes = parser_classes

# Bring arguments into shape
if spec is not None and not isinstance(spec, ir.Section):
Expand Down Expand Up @@ -237,7 +244,11 @@ def make_complete(self, **frontend_args):
frontend = frontend_args.pop('frontend', Frontend.FP)
definitions = frontend_args.get('definitions')
xmods = frontend_args.get('xmods')
parser_classes = frontend_args.get('parser_classes')
parser_classes = frontend_args.get('parser_classes', RegexParserClass.AllClasses)
if frontend == Frontend.REGEX and self._parser_classes:
if self._parser_classes == parser_classes:
return
parser_classes = parser_classes | self._parser_classes

# If this object does not have a parent, we create a temporary parent scope
# and make sure the node exists in the parent scope. This way, the existing
Expand Down Expand Up @@ -319,6 +330,11 @@ def enrich(self, definitions, recurse=False):
debug('Cannot enrich import of %s from module %s', local_name, module.name)
self.symbol_attrs.update(updated_symbol_attrs)

if imprt.symbols:
# Rebuild the symbols in the import's symbol list to obtain the correct
# expression nodes
imprt._update(symbols=tuple(symbol.clone() for symbol in imprt.symbols))

# Update any symbol table entries that have been inherited from the parent
if self.parent:
updated_symbol_attrs = {}
Expand Down Expand Up @@ -506,6 +522,17 @@ def imported_symbol_map(self):
"""
return CaseInsensitiveDict((s.name, s) for s in self.imported_symbols)

@property
def all_imports(self):
"""
Return the list of :any:`Import` in this unit and any parent scopes
"""
imports = self.imports
scope = self
while (scope := scope.parent):
imports += scope.imports
return imports

@property
def interfaces(self):
"""
Expand Down Expand Up @@ -545,6 +572,15 @@ def enum_symbols(self):
"""
return as_tuple(flatten(enum.symbols for enum in FindNodes(ir.Enumeration).visit(self.spec or ())))

@property
def definitions(self):
"""
The list of IR nodes defined by this program unit.
Returns an empty tuple by default and can be overwritten by derived nodes.
"""
return ()

@property
def symbols(self):
"""
Expand Down
Loading

0 comments on commit 13739a1

Please sign in to comment.