diff --git a/loki/backend/fgen.py b/loki/backend/fgen.py index 2ae8fb972..2576c8ca3 100644 --- a/loki/backend/fgen.py +++ b/loki/backend/fgen.py @@ -328,20 +328,25 @@ def visit_VariableDeclaration(self, o, **kwargs): # the symbol has a known derived type ignore = ['shape', 'dimensions', 'variables', 'source', 'initial'] - if isinstance(types[0].dtype, ProcedureType): + # Statement functions can share declarations with scalars, so we collect the variable types here + _var_types = [t.dtype.return_type.dtype if isinstance(t.dtype, ProcedureType) else t.dtype for t in types] + _procedure_types = [t for t in types if isinstance(t.dtype, ProcedureType)] + + if _procedure_types: # Statement functions are the only symbol with ProcedureType that should appear # in a VariableDeclaration as all other forms of procedure declarations (bindings, # pointers, EXTERNAL statements) are handled by ProcedureDeclaration. # However, the fact that statement function declarations can appear mixed with actual # variable declarations forbids this in this case. - assert types[0].is_stmt_func + assert all(t.is_stmt_func for t in _procedure_types) # TODO: We can't fully compare statement functions, yet but we can make at least sure # other declared attributes are compatible and that all have the same return type ignore += ['dtype'] - assert all(t.dtype.return_type == types[0].dtype.return_type or - t.dtype.return_type.compare(types[0].dtype.return_type, ignore=ignore) for t in types) + assert all(t.dtype.return_type == _procedure_types[0].dtype.return_type or + t.dtype.return_type.compare(_procedure_types[0].dtype.return_type, ignore=ignore) + for t in _procedure_types) - assert all(t.compare(types[0], ignore=ignore) for t in types) + assert all((t == _var_types[0]) for t in _var_types) is_function = isinstance(types[0].dtype, ProcedureType) and types[0].dtype.is_function if is_function: diff --git a/loki/frontend/fparser.py b/loki/frontend/fparser.py index 6fd54b08c..dba5911d2 100644 --- a/loki/frontend/fparser.py +++ b/loki/frontend/fparser.py @@ -1812,11 +1812,12 @@ def visit_Subroutine_Subprogram(self, o, **kwargs): rescope_symbols=True, source=source, incomplete=False ) - # Once statement functions are in place, we need to update the original declaration symbol + # Once statement functions are in place, we need to update the original declaration so that it + # contains ProcedureSymbols rather than Scalars for decl in FindNodes(ir.VariableDeclaration).visit(spec): if any(routine.symbol_attrs[s.name].is_stmt_func for s in decl.symbols): - assert all(routine.symbol_attrs[s.name].is_stmt_func for s in decl.symbols) - decl._update(symbols=tuple(s.clone() for s in decl.symbols)) + decl._update(symbols=tuple(s.clone() if routine.symbol_attrs[s.name].is_stmt_func else s + for s in decl.symbols)) # Big, but necessary hack: # For deferred array dimensions on allocatables, we infer the conceptual diff --git a/loki/frontend/regex.py b/loki/frontend/regex.py index 888cbaa8e..bd3fa79fc 100644 --- a/loki/frontend/regex.py +++ b/loki/frontend/regex.py @@ -468,8 +468,8 @@ def __init__(self): r'^(?P[ \t\w()=]*)?(?Psubroutine|function)[ \t]+(?P\w+)\b.*?$' r'(?P(?:.*?(?:^(?:abstract[ \t]+)?interface\b.*?^end[ \t]+interface)?)+)' r'(?P^contains\n(?:' - r'(?:[ \t\w()]*?subroutine.*?^end[ \t]*subroutine\b(?:[ \t]\w+)?\n)|' - r'(?:[ \t\w()]*?function.*?^end[ \t]*function\b(?:[ \t]\w+)?\n)|' + r'(?:[ \t\w()=]*?subroutine.*?^end[ \t]*subroutine\b(?:[ \t]\w+)?\n)|' + r'(?:[ \t\w()=]*?function.*?^end[ \t]*function\b(?:[ \t]\w+)?\n)|' r'(?:^#\w+.*?\n)' r')*)?' r'^end[ \t]*(?P=keyword)\b(?:[ \t](?P=name))?', diff --git a/tests/test_frontends.py b/tests/test_frontends.py index ff6ce9d59..353ecb4bf 100644 --- a/tests/test_frontends.py +++ b/tests/test_frontends.py @@ -381,6 +381,7 @@ def test_regex_subroutine_from_source(): ! arg2 j ) + use parkind1, only : jpim implicit none integer, intent(in) :: i, j integer b @@ -391,16 +392,17 @@ def test_regex_subroutine_from_source(): call routine_a() contains !abc ^$^** + integer(kind=jpim) function contained_e(i) + integer, intent(in) :: i + contained_e = i + end function + subroutine contained_c(i) integer, intent(in) :: i integer c c = 5 end subroutine contained_c ! cc£$^£$^ - integer function contained_e(i) - integer, intent(in) :: i - contained_e = i - end function subroutine contained_d(i) integer, intent(in) :: i @@ -415,7 +417,7 @@ def test_regex_subroutine_from_source(): assert not routine.is_function assert routine.arguments == () assert routine.argnames == [] - assert [r.name for r in routine.subroutines] == ['contained_c', 'contained_e', 'contained_d'] + assert [r.name for r in routine.subroutines] == ['contained_e', 'contained_c', 'contained_d'] contained_c = routine['contained_c'] assert contained_c.name == 'contained_c' @@ -1487,7 +1489,7 @@ def test_regex_fypp(): @pytest.mark.parametrize( - 'frontend', + 'frontend', available_frontends(include_regex=True, xfail=[(OMNI, 'OMNI may segfault on empty files')]) ) @pytest.mark.parametrize('fcode', ['', '\n', '\n\n\n\n']) diff --git a/tests/test_subroutine.py b/tests/test_subroutine.py index f6df24a24..8c4817d06 100644 --- a/tests/test_subroutine.py +++ b/tests/test_subroutine.py @@ -1493,8 +1493,7 @@ def test_subroutine_stmt_func(here, frontend): integer, intent(in) :: a integer, intent(out) :: b integer :: array(a) - integer :: i, j - integer :: plus, minus + integer :: i, j, plus, minus plus(i, j) = i + j minus(i, j) = i - j integer :: mult