diff --git a/loki/transform/transform_parametrise.py b/loki/transform/transform_parametrise.py index 78295f9ca..19907ca75 100644 --- a/loki/transform/transform_parametrise.py +++ b/loki/transform/transform_parametrise.py @@ -84,7 +84,7 @@ from loki.expression import symbols as sym from loki import ir from loki.visitors import Transformer, FindNodes -from loki.tools.util import is_iterable, as_tuple +from loki.tools.util import as_tuple, CaseInsensitiveDict from loki.transform.transformation import Transformation from loki.transform.transform_inline import inline_constant_parameters @@ -121,7 +121,7 @@ def error_stop(**kwargs): dic2p = {'a': 12, 'b': 11} - transformation = ParametriseTransformation(dic2p=dic2p, abort_callback=error_stop, + transformation = ParametriseTransformation(dic2p=dic2p, abort_callback=error_stop, entry_points=("driver1", "driver2")) scheduler.process(transformation=transformation) @@ -153,12 +153,7 @@ def error_stop(**kwargs): def __init__(self, dic2p, replace_by_value=False, entry_points=None, abort_callback=None, key=None): self.dic2p = dic2p self.replace_by_value = replace_by_value - if entry_points is not None: - self.entry_points = [_.upper() for _ in entry_points] - else: - self.entry_points = entry_points - if self.entry_points is not None: - assert is_iterable(entry_points) + self.entry_points = tuple(entry_point.upper() for entry_point in as_tuple(entry_points)) or None self.abort_callback = abort_callback if key is not None: self._key = key @@ -181,9 +176,10 @@ def transform_subroutine(self, routine, **kwargs): item = kwargs.get('item', None) role = kwargs.get('role', None) - _successors = kwargs.get('successors', None) - successor_map = {successor.routine.name: successor for successor in _successors} - successors = [successor.local_name.upper() for successor in _successors] + successor_map = CaseInsensitiveDict( + (successor.local_name, successor) + for successor in as_tuple(kwargs.get('successors')) + ) # decide whether subroutine is an entry point or not process_entry_point = False @@ -248,7 +244,7 @@ def transform_subroutine(self, routine, **kwargs): # remove variables to be parametrised from all call statements call_map = {} for call in FindNodes(ir.CallStatement).visit(routine.body): - if str(call.name).upper() in successors: + if str(call.name) in successor_map: successor_map[str(call.name)].trafo_data[self._key] = {} arg_map = dict(call.arg_iter()) arg_map_reversed = {v: k for k, v in arg_map.items()}