From 71b4bddacfe0870bd87d7bc6b4e07c4433434a7e Mon Sep 17 00:00:00 2001 From: Michael Lange Date: Fri, 9 Feb 2024 12:35:49 +0000 Subject: [PATCH 1/6] Frontend: Add perf-level timers around sanitize_ir --- loki/frontend/util.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/loki/frontend/util.py b/loki/frontend/util.py index fd68a1d13..62578dc46 100644 --- a/loki/frontend/util.py +++ b/loki/frontend/util.py @@ -8,6 +8,7 @@ from enum import IntEnum from pathlib import Path import codecs +from codetiming import Timer from loki.visitors import NestedTransformer, FindNodes, PatternFinder, SequenceFinder from loki.ir import ( @@ -15,7 +16,7 @@ Loop, Intrinsic, Pragma ) from loki.frontend.source import Source -from loki.logging import warning +from loki.logging import warning, perf __all__ = [ 'Frontend', 'OFP', 'OMNI', 'FP', 'REGEX', @@ -180,6 +181,7 @@ def combine_multiline_pragmas(ir): return NestedTransformer(pragma_mapper, invalidate_source=False).visit(ir) +@Timer(logger=perf, text=lambda s: f'[Loki::Frontend] Executed sanitize_ir in {s:.2f}s') def sanitize_ir(_ir, frontend, pp_registry=None, pp_info=None): """ Utility function to sanitize internal representation after creating it From 4a22a0c07070e8ec709b4772d7af08a19e5a8dc7 Mon Sep 17 00:00:00 2001 From: Michael Lange Date: Fri, 9 Feb 2024 05:23:54 +0000 Subject: [PATCH 2/6] Visitors: Better implementation of multi-node key tuple-injection This now uses `more_itertools.replace` and relies on natural recursion. 
--- loki/visitors/transform.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/loki/visitors/transform.py b/loki/visitors/transform.py index 5505596c1..26290ad0b 100644 --- a/loki/visitors/transform.py +++ b/loki/visitors/transform.py @@ -8,7 +8,7 @@ """ Visitor classes for transforming the IR """ -from more_itertools import windowed +from more_itertools import replace from loki.ir import Node, Conditional, ScopedNode from loki.tools import flatten, is_iterable, as_tuple @@ -137,10 +137,11 @@ def _inject_handle(nodes, i, old, new): for k, handle in self.mapper.items(): if is_iterable(k): - w = list(windowed(o, len(k))) - if k in w: - i = list(w).index(k) - o = o[:i] + as_tuple(handle) + o[i+len(k):] + k = as_tuple(k) + pred = lambda *args: args == k + o = tuple(replace( + o, pred=pred, substitutes=as_tuple(handle), window_size=len(k) + )) if k in o and is_iterable(handle): # Replace k by the iterable that is provided by handle o, i = _inject_handle(o, 0, k, handle) From 67100a192122a4eff9d2d0fb831df5989eaf9ce9 Mon Sep 17 00:00:00 2001 From: Michael Lange Date: Fri, 9 Feb 2024 15:53:14 +0000 Subject: [PATCH 3/6] Frontend: Perform comment clustering in-place Instead of a costly find-and-replace pattern, we re-write the utility as `Transformer` and re-assemble the tuple in-place. 
This significantly improves parsing time for large files. --- loki/frontend/util.py | 60 ++++++++++++++++++++++++----------------- tests/test_frontends.py | 51 ++++++++++++++++++++++++++++++++++- 2 files changed, 85 insertions(+), 26 deletions(-) diff --git a/loki/frontend/util.py b/loki/frontend/util.py index 62578dc46..a7d7d26a3 100644 --- a/loki/frontend/util.py +++ b/loki/frontend/util.py @@ -9,8 +9,12 @@ from pathlib import Path import codecs from codetiming import Timer +from itertools import groupby +from more_itertools import replace -from loki.visitors import NestedTransformer, FindNodes, PatternFinder, SequenceFinder +from loki.visitors import ( + NestedTransformer, FindNodes, PatternFinder, SequenceFinder, Transformer +) from loki.ir import ( Assignment, Comment, CommentBlock, VariableDeclaration, ProcedureDeclaration, Loop, Intrinsic, Pragma @@ -19,8 +23,8 @@ from loki.logging import warning, perf __all__ = [ - 'Frontend', 'OFP', 'OMNI', 'FP', 'REGEX', - 'inline_comments', 'cluster_comments', 'read_file', + 'Frontend', 'OFP', 'OMNI', 'FP', 'REGEX', 'inline_comments', + 'ClusterCommentTransformer', 'read_file', 'combine_multiline_pragmas', 'sanitize_ir' ] @@ -66,28 +70,34 @@ def inline_comments(ir): return NestedTransformer(mapper, invalidate_source=False).visit(ir) -def cluster_comments(ir): +class ClusterCommentTransformer(Transformer): """ - Cluster comments into comment blocks + Combines consecutive sets of :any:`Comment` into a :any:`CommentBlock`. 
""" - comment_mapper = {} - comment_groups = SequenceFinder(node_type=Comment).visit(ir) - for comments in comment_groups: - # Build a CommentBlock and map it to first comment - # and map remaining comments to None for removal - if all(c.source is not None for c in comments): - if all(c.source.string is not None for c in comments): - string = '\n'.join(c.source.string for c in comments) - else: - string = None - lines = {l for c in comments for l in c.source.lines if l is not None} - lines = (min(lines), max(lines)) - source = Source(lines=lines, string=string, file=comments[0].source.file) - else: - source = None - block = CommentBlock(comments, label=comments[0].label, source=source) - comment_mapper[comments] = block - return NestedTransformer(comment_mapper, invalidate_source=False).visit(ir) + + def visit_tuple(self, o, **kwargs): + """ + Find groups of :any:`Comment` and inject into the tuple. + """ + cgroups = tuple( + tuple(g) for k, g in groupby(o, key=lambda x: x.__class__) + if k == Comment + ) + cgroups = tuple(g for g in cgroups if len(g) > 1) + + for group in cgroups: + # Combine the group into a CommentBlock + source = join_source_list(tuple(p.source for p in group)) + block = CommentBlock(comments=group, label=group[0].label, source=source) + pred = lambda *args: args == group + o = tuple(replace( + o, pred=pred, substitutes=(block,), window_size=len(group) + )) + + # Then recurse over the new nodes + return tuple(self.visit(i, **kwargs) for i in o) + + visit_list = visit_tuple def inline_labels(ir): @@ -191,7 +201,7 @@ def sanitize_ir(_ir, frontend, pp_registry=None, pp_info=None): the following operations: * :any:`inline_comments` to attach inline-comments to IR nodes - * :any:`cluster_comments` to combine multi-line comments into :any:`CommentBlock` + * :any:`ClusterCommentTransformer` to combine multi-line comments into :any:`CommentBlock` * :any:`combine_multiline_pragmas` to combine multi-line pragmas into a single node @@ -215,7 +225,7 @@ def 
sanitize_ir(_ir, frontend, pp_registry=None, pp_info=None): # Perform some minor sanitation tasks _ir = inline_comments(_ir) - _ir = cluster_comments(_ir) + _ir = ClusterCommentTransformer(invalidate_source=False).visit(_ir) if frontend in (OMNI, OFP): _ir = inline_labels(_ir) diff --git a/tests/test_frontends.py b/tests/test_frontends.py index ab08cad15..1dd1d6df3 100644 --- a/tests/test_frontends.py +++ b/tests/test_frontends.py @@ -25,7 +25,7 @@ Deallocation, Associate, BasicType, OMNI, OFP, FP, Enumeration, config, REGEX, Sourcefile, Import, RawSource, CallStatement, RegexParserClass, ProcedureType, DerivedType, Comment, Pragma, - PreprocessorDirective, config_override, Section + PreprocessorDirective, config_override, Section, CommentBlock ) from loki.expression import symbols as sym @@ -1698,3 +1698,52 @@ def test_pragma_line_continuation(frontend): assert 'PRESENT' in pragmas[0].content assert 'PRIVATE' in pragmas[0].content assert 'VECTOR_LENGTH' in pragmas[0].content + + +@pytest.mark.parametrize('frontend', available_frontends()) +def test_comment_block_clustering(frontend): + """ + Test that multiple :any:`Comment` nodes are clustered into a :any:`CommentBlock`. + """ + fcode = """ +subroutine test_comment_block(a, b) + ! What is this? + ! Ohhh, ... a docstring? + real, intent(inout) :: a, b + + a = a + 1.0 + ! Never gonna + b = b + 2 + ! give you + ! up... + + a = a + b + ! Shut up, ... + ! Rick! +end subroutine test_comment_block +""" + routine = Subroutine.from_source(fcode, frontend=frontend) + + comments = FindNodes(Comment).visit(routine.spec) + assert len(comments) == 0 + blocks = FindNodes(CommentBlock).visit(routine.spec) + assert len(blocks) == 0 + + assert isinstance(routine.docstring[0], CommentBlock) + assert len(routine.docstring[0].comments) == 2 + assert routine.docstring[0].comments[0].text == '! What is this?' + assert routine.docstring[0].comments[1].text == '! Ohhh, ... a docstring?' 
+ + comments = FindNodes(Comment).visit(routine.body) + assert len(comments) == 2 if frontend == FP else 1 + assert comments[-1].text == '! Never gonna' + + blocks = FindNodes(CommentBlock).visit(routine.body) + assert len(blocks) == 2 + assert len(blocks[0].comments) == 3 if frontend == FP else 2 + assert blocks[0].comments[0].text == '! give you' + assert blocks[0].comments[1].text == '! up...' + + assert len(blocks[1].comments) == 2 + assert blocks[1].comments[0].text == '! Shut up, ...' + assert blocks[1].comments[1].text == '! Rick!' From 53419fc2e334e9cb7769ecc08f9f3480806f6c0b Mon Sep 17 00:00:00 2001 From: Michael Lange Date: Fri, 9 Feb 2024 18:18:28 +0000 Subject: [PATCH 4/6] Frontend: Log performance-sensitive parts of frontends under perf --- loki/frontend/fparser.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/loki/frontend/fparser.py b/loki/frontend/fparser.py index 8aee166c2..4195f8819 100644 --- a/loki/frontend/fparser.py +++ b/loki/frontend/fparser.py @@ -31,7 +31,7 @@ StringConcat, ParenthesisedAdd, ParenthesisedMul, ParenthesisedDiv, ParenthesisedPow ) from loki.expression import ExpressionDimensionsMapper, AttachScopes, AttachScopesMapper -from loki.logging import debug, info, warning, error +from loki.logging import debug, perf, info, warning, error from loki.tools import as_tuple, flatten, CaseInsensitiveDict, LazyNodeLookup from loki.pragma_utils import ( attach_pragmas, process_dimension_pragmas, detach_pragmas, pragmas_attached @@ -55,7 +55,7 @@ def parse_fparser_file(filename): return parse_fparser_source(source=fcode) -@Timer(logger=debug, text=lambda s: f'[Loki::FP] Executed parse_fparser_source in {s:.2f}s') +@Timer(logger=perf, text=lambda s: f'[Loki::FP] Executed parse_fparser_source in {s:.2f}s') def parse_fparser_source(source): """ Generate a parse tree from string @@ -77,7 +77,7 @@ def parse_fparser_source(source): return f2008_parser(reader) -@Timer(logger=debug, text=lambda s: f'[Loki::FP] Executed 
parse_fparser_ast in {s:.2f}s') +@Timer(logger=perf, text=lambda s: f'[Loki::FP] Executed parse_fparser_ast in {s:.2f}s') def parse_fparser_ast(ast, raw_source, pp_info=None, definitions=None, scope=None): """ Generate an internal IR from fparser parse tree From d45767f3c71f4a6d3b8418bdaeca2c60949e688b Mon Sep 17 00:00:00 2001 From: Michael Lange Date: Sat, 10 Feb 2024 04:12:18 +0000 Subject: [PATCH 5/6] Frontend: Combine multi-line pragmas inplace with Transformer The previous search-and-replace pattern was very costly, so we now do this with a custom `Transformer` that applies this all directly. --- loki/frontend/source.py | 2 +- loki/frontend/util.py | 80 ++++++++++++++++++----------------------- tests/test_frontends.py | 11 +++++- tests/test_source.py | 28 +++++++-------- 4 files changed, 59 insertions(+), 62 deletions(-) diff --git a/loki/frontend/source.py b/loki/frontend/source.py index 1ad256dc1..49caa08bc 100644 --- a/loki/frontend/source.py +++ b/loki/frontend/source.py @@ -519,4 +519,4 @@ def join_source_list(source_list): newlines = 0 string += '\n' * newlines + source.string lines[1] = source.lines[1] if source.lines[1] else lines[1] + newlines + source.string.count('\n') - return Source(lines, string, source_list[0].file) + return Source(tuple(lines), string, source_list[0].file) diff --git a/loki/frontend/util.py b/loki/frontend/util.py index a7d7d26a3..862a954f2 100644 --- a/loki/frontend/util.py +++ b/loki/frontend/util.py @@ -10,7 +10,7 @@ import codecs from codetiming import Timer from itertools import groupby -from more_itertools import replace +from more_itertools import replace, split_after from loki.visitors import ( NestedTransformer, FindNodes, PatternFinder, SequenceFinder, Transformer @@ -19,13 +19,13 @@ Assignment, Comment, CommentBlock, VariableDeclaration, ProcedureDeclaration, Loop, Intrinsic, Pragma ) -from loki.frontend.source import Source +from loki.frontend.source import Source, join_source_list from loki.logging import warning, 
perf __all__ = [ 'Frontend', 'OFP', 'OMNI', 'FP', 'REGEX', 'inline_comments', 'ClusterCommentTransformer', 'read_file', - 'combine_multiline_pragmas', 'sanitize_ir' + 'CombineMultilinePragmasTransformer', 'sanitize_ir' ] @@ -146,49 +146,37 @@ def read_file(file_path): return source -def combine_multiline_pragmas(ir): +class CombineMultilinePragmasTransformer(Transformer): """ - Combine multiline pragmas into single pragma nodes + Combine multiline :any:`Pragma` nodes into single ones. """ - pragma_mapper = {} - pragma_groups = SequenceFinder(node_type=Pragma).visit(ir) - for pragma_list in pragma_groups: - collected_pragmas = [] - for pragma in pragma_list: - if not collected_pragmas: - if pragma.content.rstrip().endswith('&'): - # This is the beginning of a multiline pragma - collected_pragmas = [pragma] - else: - # This is the continuation of a multiline pragma - collected_pragmas += [pragma] - - if pragma.keyword != collected_pragmas[0].keyword: - raise RuntimeError('Pragma keyword mismatch after line continuation: ' + - f'{collected_pragmas[0].keyword} != {pragma.keyword}') - - if not pragma.content.rstrip().endswith('&'): - # This is the last line of a multiline pragma - content = [p.content.strip()[:-1].rstrip() for p in collected_pragmas[:-1]] - content = ' '.join(content) + ' ' + pragma.content.strip() - - if all(p.source is not None for p in collected_pragmas): - if all(p.source.string is not None for p in collected_pragmas): - string = '\n'.join(p.source.string for p in collected_pragmas) - else: - string = None - lines = (collected_pragmas[0].source.lines[0], collected_pragmas[-1].source.lines[1]) - source = Source(lines=lines, string=string, file=pragma.source.file) - else: - source = None - - new_pragma = Pragma(keyword=pragma.keyword, content=content, source=source) - pragma_mapper[collected_pragmas[0]] = new_pragma - pragma_mapper.update({p: None for p in collected_pragmas[1:]}) - - collected_pragmas = [] - - return NestedTransformer(pragma_mapper, 
invalidate_source=False).visit(ir) + + def visit_tuple(self, o, **kwargs): + """ + Finds multi-line pragmas and combines them in-place. + """ + pgroups = tuple( + tuple(g) for k, g in groupby(o, key=lambda x: x.__class__) + if k == Pragma + ) + pgroups = tuple(g for g in pgroups if len(g) > 1) + + for group in pgroups: + # Separate sets of consecutive multi-line pragmas + pred = lambda p: not p.content.rstrip().endswith('&') # pylint: disable=unnecessary-lambda-assignment + for pragmaset in split_after(group, pred=pred): + # Combine into a single pragma and add to map + source = join_source_list(tuple(p.source for p in pragmaset)) + content = ' '.join(p.content.rstrip(' &') for p in pragmaset) + new_pragma = Pragma( + keyword=pragmaset[0].keyword, content=content, source=source + ) + pred = lambda *args: args == tuple(pragmaset) + o = tuple(replace( + o, pred=pred, substitutes=(new_pragma,), window_size=len(pragmaset) + )) + + return tuple(self.visit(i, **kwargs) for i in o) @Timer(logger=perf, text=lambda s: f'[Loki::Frontend] Executed sanitize_ir in {s:.2f}s') @@ -202,7 +190,7 @@ def sanitize_ir(_ir, frontend, pp_registry=None, pp_info=None): * :any:`inline_comments` to attach inline-comments to IR nodes * :any:`ClusterCommentTransformer` to combine multi-line comments into :any:`CommentBlock` - * :any:`combine_multiline_pragmas` to combine multi-line pragmas into a + * :any:`CombineMultilinePragmasTransformer` to combine multi-line pragmas into a single node Parameters @@ -231,6 +219,6 @@ def sanitize_ir(_ir, frontend, pp_registry=None, pp_info=None): _ir = inline_labels(_ir) if frontend in (FP, OFP): - _ir = combine_multiline_pragmas(_ir) + _ir = CombineMultilinePragmasTransformer(invalidate_source=False).visit(_ir) return _ir diff --git a/tests/test_frontends.py b/tests/test_frontends.py index 1dd1d6df3..0b21da291 100644 --- a/tests/test_frontends.py +++ b/tests/test_frontends.py @@ -1686,18 +1686,27 @@ def test_pragma_line_continuation(frontend): !$ACC& 
PRESENT(ZRDG_LCVQ,ZFLU_QSATS,ZRDG_CVGQ) & !$ACC& PRIVATE (JBLK) & !$ACC& VECTOR_LENGTH (YDCPG_OPTS%KLON) +!$ACC SEQUENTIAL END SUBROUTINE TOTO """ routine = Subroutine.from_source(fcode, frontend=frontend) pragmas = FindNodes(Pragma).visit(routine.body) - assert len(pragmas) == 1 + + assert len(pragmas) == 2 assert pragmas[0].keyword == 'ACC' assert 'PARALLEL' in pragmas[0].content assert 'PRESENT' in pragmas[0].content assert 'PRIVATE' in pragmas[0].content assert 'VECTOR_LENGTH' in pragmas[0].content + assert pragmas[1].content == 'SEQUENTIAL' + + # Check that source object was generated right + assert pragmas[0].source + assert pragmas[0].source.lines == (8, 8) if frontend == OMNI else (8, 11) + assert pragmas[1].source + assert pragmas[1].source.lines == (12, 12) @pytest.mark.parametrize('frontend', available_frontends()) diff --git a/tests/test_source.py b/tests/test_source.py index 58de2be1e..e797c5e8a 100644 --- a/tests/test_source.py +++ b/tests/test_source.py @@ -223,8 +223,8 @@ def test_source_to_lines(): ( [], None ), ( - [Source([1, 2], 'subroutine my_routine\nimplicit none'), Source([3, None], 'end subroutine my_routine')], - Source([1, 3], 'subroutine my_routine\nimplicit none\nend subroutine my_routine') + [Source((1, 2), 'subroutine my_routine\nimplicit none'), Source((3, None), 'end subroutine my_routine')], + Source((1, 3), 'subroutine my_routine\nimplicit none\nend subroutine my_routine') ), ( [ Source([1, None], 'subroutine my_routine'), @@ -235,7 +235,7 @@ def test_source_to_lines(): Source([6, 7], ' var_1 = 1._real64\n var_2 = 2._real64'), Source([8, None], 'end subroutine my_routine'), ], - Source([1, 8], ''' + Source((1, 8), ''' subroutine my_routine use iso_fortran_env, only: real64 implicit none @@ -247,22 +247,22 @@ def test_source_to_lines(): '''.strip()) ), ( [ - Source([5, 5], 'integer ::'), - Source([5, None], ' var1,'), - Source([5, 5], ' var2') + Source((5, 5), 'integer ::'), + Source((5, None), ' var1,'), + Source((5, 5), ' var2') 
], - Source([5, 5], 'integer :: var1, var2') + Source((5, 5), 'integer :: var1, var2') ), ( - [Source([1, 1], 'print *,* "hello world!"')], Source([1, 1], 'print *,* "hello world!"') + [Source((1, 1), 'print *,* "hello world!"')], Source((1, 1), 'print *,* "hello world!"') ), ( - [Source([13, 19], '! line with less line breaks than reported'), Source([20, None], '! here')], - Source([13, 20], '! line with less line breaks than reported\n! here') + [Source((13, 19), '! line with less line breaks than reported'), Source((20, None), '! here')], + Source((13, 20), '! line with less line breaks than reported\n! here') ), ( - [Source([7, None], '! Some line'), Source([12, None], '! Some other line')], - Source([7, 12], '! Some line\n\n\n\n\n! Some other line') + [Source((7, None), '! Some line'), Source([12, None], '! Some other line')], + Source((7, 12), '! Some line\n\n\n\n\n! Some other line') ), ( - [Source([3, 4], '! Some line\n! With line break'), Source([6, None], '! Other line\n! And new line')], - Source([3, 7], '! Some line\n! With line break\n\n! Other line\n! And new line') + [Source((3, 4), '! Some line\n! With line break'), Source([6, None], '! Other line\n! And new line')], + Source((3, 7), '! Some line\n! With line break\n\n! Other line\n! 
And new line') ) )) def test_join_source_list(source_list, expected): From e8ba69eeaa6db8852bdc45ed0f344f994d29cc2a Mon Sep 17 00:00:00 2001 From: Michael Lange Date: Sat, 10 Feb 2024 04:37:25 +0000 Subject: [PATCH 6/6] Tools: Move tuple-based find and replace patterns to utilities --- loki/frontend/util.py | 32 ++++++++-------------------- loki/tools/util.py | 43 +++++++++++++++++++++++++++++++++++++- loki/visitors/transform.py | 10 ++------- 3 files changed, 53 insertions(+), 32 deletions(-) diff --git a/loki/frontend/util.py b/loki/frontend/util.py index 862a954f2..cc643ec9e 100644 --- a/loki/frontend/util.py +++ b/loki/frontend/util.py @@ -9,18 +9,19 @@ from pathlib import Path import codecs from codetiming import Timer -from itertools import groupby -from more_itertools import replace, split_after +from more_itertools import split_after from loki.visitors import ( - NestedTransformer, FindNodes, PatternFinder, SequenceFinder, Transformer + NestedTransformer, FindNodes, PatternFinder, Transformer ) from loki.ir import ( Assignment, Comment, CommentBlock, VariableDeclaration, ProcedureDeclaration, Loop, Intrinsic, Pragma ) -from loki.frontend.source import Source, join_source_list +from loki.frontend.source import join_source_list from loki.logging import warning, perf +from loki.tools import group_by_class, replace_windowed + __all__ = [ 'Frontend', 'OFP', 'OMNI', 'FP', 'REGEX', 'inline_comments', @@ -79,20 +80,12 @@ def visit_tuple(self, o, **kwargs): """ Find groups of :any:`Comment` and inject into the tuple. 
""" - cgroups = tuple( - tuple(g) for k, g in groupby(o, key=lambda x: x.__class__) - if k == Comment - ) - cgroups = tuple(g for g in cgroups if len(g) > 1) - + cgroups = group_by_class(o, Comment) for group in cgroups: # Combine the group into a CommentBlock source = join_source_list(tuple(p.source for p in group)) block = CommentBlock(comments=group, label=group[0].label, source=source) - pred = lambda *args: args == group - o = tuple(replace( - o, pred=pred, substitutes=(block,), window_size=len(group) - )) + o = replace_windowed(o, group, subs=(block,)) # Then recurse over the new nodes return tuple(self.visit(i, **kwargs) for i in o) @@ -155,11 +148,7 @@ def visit_tuple(self, o, **kwargs): """ Finds multi-line pragmas and combines them in-place. """ - pgroups = tuple( - tuple(g) for k, g in groupby(o, key=lambda x: x.__class__) - if k == Pragma - ) - pgroups = tuple(g for g in pgroups if len(g) > 1) + pgroups = group_by_class(o, Pragma) for group in pgroups: # Separate sets of consecutive multi-line pragmas @@ -171,10 +160,7 @@ def visit_tuple(self, o, **kwargs): new_pragma = Pragma( keyword=pragmaset[0].keyword, content=content, source=source ) - pred = lambda *args: args == tuple(pragmaset) - o = tuple(replace( - o, pred=pred, substitutes=(new_pragma,), window_size=len(pragmaset) - )) + o = replace_windowed(o, pragmaset, subs=(new_pragma,)) return tuple(self.visit(i, **kwargs) for i in o) diff --git a/loki/tools/util.py b/loki/tools/util.py index 3341cc371..cb56740b7 100644 --- a/loki/tools/util.py +++ b/loki/tools/util.py @@ -16,6 +16,8 @@ from subprocess import run, PIPE, STDOUT, CalledProcessError from contextlib import contextmanager from pathlib import Path +from itertools import groupby +from more_itertools import replace try: import yaml @@ -31,7 +33,7 @@ 'execute', 'CaseInsensitiveDict', 'strip_inline_comments', 'binary_insertion_sort', 'cached_func', 'optional', 'LazyNodeLookup', 'yaml_include_constructor', 'auto_post_mortem_debugger', 
'set_excepthook', - 'timeout', 'WeakrefProperty' + 'timeout', 'WeakrefProperty', 'group_by_class', 'replace_windowed' ] @@ -628,3 +630,42 @@ def __set__(self, obj, value): obj.__dict__[self._name] = value else: setattr(obj, self._name, value) + + +def group_by_class(iterable, klass): + """ + Find groups of consecutive instances of the same type with more + than one element. + + Parameters + ---------- + iterable : iterable + Input iterable from which to extract groups + klass : type + Type by which to group elements in the given iterable + """ + groups = tuple( + tuple(g) for k, g in groupby(iterable, key=lambda x: x.__class__) + if k == klass + ) + return tuple(g for g in groups if len(g) > 1) + + +def replace_windowed(iterable, group, subs): + """ + Replace a set of consecutive elements in a larger iterable. + + Parameters + ---------- + iterable : iterable + Input iterable in which to replace elements + group : iterable + Group of elements to replace in ``iterable`` + subs : any + Replacement for ``group`` in ``iterable`` + """ + group = as_tuple(group) + return tuple(replace( + iterable, pred=lambda *args: args == group, + substitutes=as_tuple(subs), window_size=len(group) + )) diff --git a/loki/visitors/transform.py b/loki/visitors/transform.py index 26290ad0b..64c1d5a56 100644 --- a/loki/visitors/transform.py +++ b/loki/visitors/transform.py @@ -8,10 +8,8 @@ """ Visitor classes for transforming the IR """ -from more_itertools import replace - from loki.ir import Node, Conditional, ScopedNode -from loki.tools import flatten, is_iterable, as_tuple +from loki.tools import flatten, is_iterable, as_tuple, replace_windowed from loki.visitors.visitor import Visitor @@ -137,11 +135,7 @@ def _inject_handle(nodes, i, old, new): for k, handle in self.mapper.items(): if is_iterable(k): - k = as_tuple(k) - pred = lambda *args: args == k - o = tuple(replace( - o, pred=pred, substitutes=as_tuple(handle), window_size=len(k) - )) + o = replace_windowed(o, k, subs=handle) if k 
in o and is_iterable(handle): # Replace k by the iterable that is provided by handle o, i = _inject_handle(o, 0, k, handle)