diff --git a/loki/backend/fgen.py b/loki/backend/fgen.py index 2ae8fb972..486caf319 100644 --- a/loki/backend/fgen.py +++ b/loki/backend/fgen.py @@ -328,20 +328,24 @@ def visit_VariableDeclaration(self, o, **kwargs): # the symbol has a known derived type ignore = ['shape', 'dimensions', 'variables', 'source', 'initial'] - if isinstance(types[0].dtype, ProcedureType): + # Statement functions can share declarations with scalars, so we collect the variable types here + _var_types = [t.dtype.return_type.dtype if isinstance(t.dtype, ProcedureType) else t.dtype for t in types] + _procedure_types = [t for t in types if isinstance(t, ProcedureType)] + + if len(_procedure_types) > 0: # Statement functions are the only symbol with ProcedureType that should appear # in a VariableDeclaration as all other forms of procedure declarations (bindings, # pointers, EXTERNAL statements) are handled by ProcedureDeclaration. # However, the fact that statement function declarations can appear mixed with actual # variable declarations forbids this in this case. - assert types[0].is_stmt_func + assert _procedure_types[0].is_stmt_func # TODO: We can't fully compare statement functions, yet but we can make at least sure # other declared attributes are compatible and that all have the same return type ignore += ['dtype'] assert all(t.dtype.return_type == types[0].dtype.return_type or - t.dtype.return_type.compare(types[0].dtype.return_type, ignore=ignore) for t in types) + t.dtype.return_type.compare(types[0].dtype.return_type, ignore=ignore) for t in _procedure_types) - assert all(t.compare(types[0], ignore=ignore) for t in types) + assert all((t == _var_types[0]) for t in _var_types) is_function = isinstance(types[0].dtype, ProcedureType) and types[0].dtype.is_function if is_function: diff --git a/loki/frontend/fparser.py b/loki/frontend/fparser.py index 6fd54b08c..dba5911d2 100644 --- a/loki/frontend/fparser.py +++ b/loki/frontend/fparser.py @@ -1812,11 +1812,12 @@ def visit_Subroutine_Subprogram(self, o, **kwargs): rescope_symbols=True, source=source, incomplete=False ) - # Once statement functions are in place, we need to update the original declaration symbol + # Once statement functions are in place, we need to update the original declaration so that it + # contains ProcedureSymbols rather than Scalars for decl in FindNodes(ir.VariableDeclaration).visit(spec): if any(routine.symbol_attrs[s.name].is_stmt_func for s in decl.symbols): - assert all(routine.symbol_attrs[s.name].is_stmt_func for s in decl.symbols) - decl._update(symbols=tuple(s.clone() for s in decl.symbols)) + decl._update(symbols=tuple(s.clone() if routine.symbol_attrs[s.name].is_stmt_func else s + for s in decl.symbols)) # Big, but necessary hack: # For deferred array dimensions on allocatables, we infer the conceptual diff --git a/tests/test_subroutine.py b/tests/test_subroutine.py index f6df24a24..8c4817d06 100644 --- a/tests/test_subroutine.py +++ b/tests/test_subroutine.py @@ -1493,8 +1493,7 @@ def test_subroutine_stmt_func(here, frontend): integer, intent(in) :: a integer, intent(out) :: b integer :: array(a) - integer :: i, j - integer :: plus, minus + integer :: i, j, plus, minus plus(i, j) = i + j minus(i, j) = i - j integer :: mult