From 0bb0a7b89720d468a350f982b62bf9a2fcb14113 Mon Sep 17 00:00:00 2001 From: Balthasar Reuter Date: Sun, 26 Mar 2023 22:41:53 +0100 Subject: [PATCH] Fix infinite recursion on abstract procedure declaration --- loki/expression/mappers.py | 28 +++++++++++++----- loki/frontend/ofp.py | 2 +- loki/ir.py | 10 ++++++- loki/transform/transform_associates.py | 2 +- tests/test_derived_types.py | 39 ++++++++++++++++++++++++++ 5 files changed, 71 insertions(+), 10 deletions(-) diff --git a/loki/expression/mappers.py b/loki/expression/mappers.py index 5c435fb3d..900bb017e 100644 --- a/loki/expression/mappers.py +++ b/loki/expression/mappers.py @@ -23,7 +23,7 @@ except ImportError: _intrinsic_fortran_names = () -from loki.ir import VariableDeclaration, Import +from loki.ir import DECLARATION_NODES from loki.logging import debug from loki.tools import as_tuple, flatten from loki.types import SymbolAttributes, BasicType @@ -523,7 +523,7 @@ def __call__(self, expr, *args, **kwargs): return None kwargs.setdefault( 'recurse_to_declaration_attributes', - 'current_node' not in kwargs or isinstance(kwargs['current_node'], (VariableDeclaration, Import)) + 'current_node' not in kwargs or isinstance(kwargs['current_node'], DECLARATION_NODES) ) new_expr = super().__call__(expr, *args, **kwargs) if getattr(expr, 'source', None): @@ -583,11 +583,25 @@ def map_variable_symbol(self, expr, *args, **kwargs): # it does not affect the outcome of expr.clone expr.scope.symbol_attrs[expr.name] = expr.type.clone(initial=initial) - bind_names = self.rec(expr.type.bind_names, *args, **kwargs) - if not (bind_names is None or all(new is old for new, old in zip_longest(bind_names, expr.type.bind_names))): - # Update symbol table entry for bind_names directly because with a scope attached - # it does not affect the outcome of expr.clone - expr.scope.symbol_attrs[expr.name] = expr.type.clone(bind_names=as_tuple(bind_names)) + if kwargs['recurse_to_declaration_attributes']: + _kwargs = kwargs.copy() + _kwargs['recurse_to_declaration_attributes'] = False + if (old_bind_names := expr.type.bind_names): + bind_names = () + for bind_name in old_bind_names: + if bind_name == expr.name: + # FIXME: This is a hack to work around situations where an + # explicit interface is used with the same name as the + # type bound procedure. This hands down the correct scope. + __kwargs = _kwargs.copy() + __kwargs['scope'] = expr.scope.parent + bind_names += (self.rec(bind_name, *args, **__kwargs),) + else: + bind_names += (self.rec(bind_name, *args, **_kwargs),) + if bind_names and any(new is not old for new, old in zip_longest(bind_names, expr.type.bind_names)): + # Update symbol table entry for bind_names directly because with a scope attached + # it does not affect the outcome of expr.clone + expr.scope.symbol_attrs[expr.name] = expr.type.clone(bind_names=bind_names) parent = self.rec(expr.parent, *args, **kwargs) if parent is expr.parent and (kind is expr.type.kind or expr.scope): diff --git a/loki/frontend/ofp.py b/loki/frontend/ofp.py index 742b95b16..25d5aa39c 100644 --- a/loki/frontend/ofp.py +++ b/loki/frontend/ofp.py @@ -609,7 +609,7 @@ def visit_specific_binding(self, o, **kwargs): type=SymbolAttributes(ProcedureType(interface)) ) - _type = interface.type + _type = interface.type.clone(bind_names=(interface,)) elif o.attrib['procedureName']: # Binding provided ( => ) diff --git a/loki/ir.py b/loki/ir.py index d1680a2cc..be93a09a6 100644 --- a/loki/ir.py +++ b/loki/ir.py @@ -39,7 +39,9 @@ 'Comment', 'CommentBlock', 'Pragma', 'PreprocessorDirective', 'Import', 'VariableDeclaration', 'ProcedureDeclaration', 'DataDeclaration', 'StatementFunction', 'TypeDef', 'MultiConditional', 'MaskedStatement', - 'Intrinsic', 'Enumeration', 'RawSource' + 'Intrinsic', 'Enumeration', 'RawSource', + # List of nodes with specific properties + 'DECLARATION_NODES' ] # Configuration for validation mechanism via pydantic @@ -1821,3 +1823,9 @@ class RawSource(LeafNode, _RawSourceBase): def __repr__(self): return f'RawSource:: {truncate_string(self.text.strip())}' + + +DECLARATION_NODES = (Import, VariableDeclaration, ProcedureDeclaration) +""" +List of IR nodes that are considered to be the authority on a symbol's attributes +""" diff --git a/loki/transform/transform_associates.py b/loki/transform/transform_associates.py index 987811fd6..706be7b34 100644 --- a/loki/transform/transform_associates.py +++ b/loki/transform/transform_associates.py @@ -40,7 +40,7 @@ class ResolveAssociatesTransformer(Transformer): corresponding `selector` expression defined in ``associations``. """ - def visit_Associate(self, o, **kwargs): + def visit_Associate(self, o, **kwargs): # pylint: disable=unused-argument # First head-recurse, so that all associate blocks beneath are resolved body = self.visit(o.body) diff --git a/tests/test_derived_types.py b/tests/test_derived_types.py index c4dcdb67c..046c09bf8 100644 --- a/tests/test_derived_types.py +++ b/tests/test_derived_types.py @@ -1355,3 +1355,42 @@ def test_derived_types_nested_type(frontend): assert assignment.rhs.parent.type.dtype.typedef is module['some_type'] assert assignment.rhs.parent.parent.type.dtype.name == 'other_type' assert assignment.rhs.parent.parent.type.dtype.typedef is module['other_type'] + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_derived_types_abstract_deferred_procedure(frontend): + fcode = """ +module some_mod + implicit none + type, abstract :: abstract_type + contains + procedure (some_proc), deferred :: some_proc + procedure (other_proc), deferred :: other_proc + end type abstract_type + + abstract interface + subroutine some_proc(this) + import abstract_type + class(abstract_type), intent(in) :: this + end subroutine some_proc + end interface + + abstract interface + subroutine other_proc(this) + import abstract_type + class(abstract_type), intent(inout) :: this + end subroutine other_proc + end interface +end module some_mod + """.strip() + module = Module.from_source(fcode, frontend=frontend) + typedef = module['abstract_type'] + assert typedef.abstract is True + assert typedef.variables == ('some_proc', 'other_proc') + for symbol in typedef.variables: + assert isinstance(symbol, ProcedureSymbol) + assert isinstance(symbol.type.dtype, ProcedureType) + assert symbol.type.dtype.name.lower() == symbol.name.lower() + assert symbol.type.bind_names == (symbol,) + assert symbol.scope is typedef + assert symbol.type.bind_names[0].scope is module