diff --git a/loki/frontend/fparser.py b/loki/frontend/fparser.py index 8aee166c2..4195f8819 100644 --- a/loki/frontend/fparser.py +++ b/loki/frontend/fparser.py @@ -31,7 +31,7 @@ StringConcat, ParenthesisedAdd, ParenthesisedMul, ParenthesisedDiv, ParenthesisedPow ) from loki.expression import ExpressionDimensionsMapper, AttachScopes, AttachScopesMapper -from loki.logging import debug, info, warning, error +from loki.logging import debug, perf, info, warning, error from loki.tools import as_tuple, flatten, CaseInsensitiveDict, LazyNodeLookup from loki.pragma_utils import ( attach_pragmas, process_dimension_pragmas, detach_pragmas, pragmas_attached @@ -55,7 +55,7 @@ def parse_fparser_file(filename): return parse_fparser_source(source=fcode) -@Timer(logger=debug, text=lambda s: f'[Loki::FP] Executed parse_fparser_source in {s:.2f}s') +@Timer(logger=perf, text=lambda s: f'[Loki::FP] Executed parse_fparser_source in {s:.2f}s') def parse_fparser_source(source): """ Generate a parse tree from string @@ -77,7 +77,7 @@ def parse_fparser_source(source): return f2008_parser(reader) -@Timer(logger=debug, text=lambda s: f'[Loki::FP] Executed parse_fparser_ast in {s:.2f}s') +@Timer(logger=perf, text=lambda s: f'[Loki::FP] Executed parse_fparser_ast in {s:.2f}s') def parse_fparser_ast(ast, raw_source, pp_info=None, definitions=None, scope=None): """ Generate an internal IR from fparser parse tree diff --git a/loki/frontend/source.py b/loki/frontend/source.py index 1ad256dc1..49caa08bc 100644 --- a/loki/frontend/source.py +++ b/loki/frontend/source.py @@ -519,4 +519,4 @@ def join_source_list(source_list): newlines = 0 string += '\n' * newlines + source.string lines[1] = source.lines[1] if source.lines[1] else lines[1] + newlines + source.string.count('\n') - return Source(lines, string, source_list[0].file) + return Source(tuple(lines), string, source_list[0].file) diff --git a/loki/frontend/util.py b/loki/frontend/util.py index fd68a1d13..cc643ec9e 100644 --- 
def _join_node_sources(nodes):
    """
    Build one combined source object for a group of IR nodes.

    Parameters
    ----------
    nodes : iterable of :any:`Node`
        The nodes whose ``source`` attributes should be merged.

    Returns
    -------
    The result of :any:`join_source_list` over all node sources, or
    ``None`` if at least one node carries no source information.
    """
    sources = tuple(node.source for node in nodes)
    # join_source_list reads ``.string``/``.lines``/``.file`` from every
    # entry, so a missing source has to be filtered out up front
    if any(source is None for source in sources):
        return None
    return join_source_list(sources)


class ClusterCommentTransformer(Transformer):
    """
    Combines consecutive sets of :any:`Comment` into a :any:`CommentBlock`.
    """

    def visit_tuple(self, o, **kwargs):
        """
        Find groups of :any:`Comment` and inject into the tuple.

        Every run of two or more directly adjacent :any:`Comment` nodes
        in ``o`` is replaced by a single :any:`CommentBlock` that takes
        the label of the first comment and, where available, the joined
        source of all comments in the run.
        """
        for group in group_by_class(o, Comment):
            # Combine the group into a CommentBlock
            block = CommentBlock(
                comments=group, label=group[0].label,
                source=_join_node_sources(group)
            )
            o = replace_windowed(o, group, subs=(block,))

        # Then recurse over the new nodes
        return tuple(self.visit(i, **kwargs) for i in o)

    visit_list = visit_tuple


class CombineMultilinePragmasTransformer(Transformer):
    """
    Combine multiline :any:`Pragma` nodes into single ones.
    """

    @staticmethod
    def _ends_pragma(pragma):
        """
        Return ``True`` if ``pragma`` does not end in a continuation ``&``.
        """
        return not pragma.content.rstrip().endswith('&')

    def visit_tuple(self, o, **kwargs):
        """
        Finds multi-line pragmas and combines them in-place.

        Runs of adjacent :any:`Pragma` nodes are split after every pragma
        that does not end in ``&``; each resulting set of two or more
        pragmas is replaced by one :any:`Pragma` that uses the keyword of
        the first line and the space-joined contents with trailing
        continuation markers removed.
        """
        for group in group_by_class(o, Pragma):
            # Separate sets of consecutive multi-line pragmas
            for pragmaset in split_after(group, pred=self._ends_pragma):
                if len(pragmaset) < 2:
                    # A lone pragma has nothing to be combined with, so
                    # keep the existing node instead of rebuilding it
                    continue

                content = ' '.join(p.content.rstrip(' &') for p in pragmaset)
                new_pragma = Pragma(
                    keyword=pragmaset[0].keyword, content=content,
                    source=_join_node_sources(pragmaset)
                )
                o = replace_windowed(o, pragmaset, subs=(new_pragma,))

        return tuple(self.visit(i, **kwargs) for i in o)

    # Handle lists the same way as tuples, as ClusterCommentTransformer does
    visit_list = visit_tuple
def group_by_class(iterable, klass):
    """
    Collect runs of adjacent elements whose class is ``klass``.

    Only runs with at least two elements are reported; isolated
    instances of ``klass`` are ignored.

    Parameters
    ----------
    iterable : iterable
        Input iterable from which to extract groups
    klass : type
        Type by which to group elements in the given iterable

    Returns
    -------
    tuple of tuple
        One inner tuple per run, in order of appearance
    """
    def _class_of(item):
        return item.__class__

    runs = []
    for cls, members in groupby(iterable, key=_class_of):
        if cls == klass:
            run = tuple(members)
            # Single elements do not constitute a group
            if len(run) > 1:
                runs.append(run)
    return tuple(runs)
def replace_windowed(iterable, group, subs):
    """
    Replace a set of consecutive elements in a larger iterable.

    Parameters
    ----------
    iterable : iterable
        Input iterable in which to replace elements
    group : iterable
        Group of elements to replace in ``iterable``
    subs : any
        Replacement for ``group`` in ``iterable``

    Returns
    -------
    tuple
        The elements of ``iterable`` with every window that equals
        ``group`` substituted by ``subs``; if ``group`` is empty the
        elements are returned unchanged.
    """
    group = as_tuple(group)
    if not group:
        # An empty group would yield a window size of zero, which
        # ``more_itertools.replace`` rejects; there is nothing to replace
        return tuple(iterable)

    return tuple(replace(
        iterable, pred=lambda *args: args == group,
        substitutes=as_tuple(subs), window_size=len(group)
    ))
@pytest.mark.parametrize('frontend', available_frontends())
def test_comment_block_clustering(frontend):
    """
    Test that multiple :any:`Comment` nodes are combined into a
    :any:`CommentBlock`.
    """
    fcode = """
subroutine test_comment_block(a, b)
  ! What is this?
  ! Ohhh, ... a docstring?
  real, intent(inout) :: a, b

  a = a + 1.0
  ! Never gonna
  b = b + 2
  ! give you
  ! up...

  a = a + b
  ! Shut up, ...
  ! Rick!
end subroutine test_comment_block
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    comments = FindNodes(Comment).visit(routine.spec)
    assert len(comments) == 0
    blocks = FindNodes(CommentBlock).visit(routine.spec)
    assert len(blocks) == 0

    assert isinstance(routine.docstring[0], CommentBlock)
    assert len(routine.docstring[0].comments) == 2
    assert routine.docstring[0].comments[0].text == '! What is this?'
    assert routine.docstring[0].comments[1].text == '! Ohhh, ... a docstring?'

    # The expected counts are parenthesised on purpose: without the
    # parentheses the conditional expression binds looser than ``==`` and
    # the assertion is trivially true for every frontend other than FP
    comments = FindNodes(Comment).visit(routine.body)
    assert len(comments) == (2 if frontend == FP else 1)
    assert comments[-1].text == '! Never gonna'

    blocks = FindNodes(CommentBlock).visit(routine.body)
    assert len(blocks) == 2
    assert len(blocks[0].comments) == (3 if frontend == FP else 2)
    assert blocks[0].comments[0].text == '! give you'
    assert blocks[0].comments[1].text == '! up...'

    assert len(blocks[1].comments) == 2
    assert blocks[1].comments[0].text == '! Shut up, ...'
    assert blocks[1].comments[1].text == '! Rick!'
Some other line')], - Source([7, 12], '! Some line\n\n\n\n\n! Some other line') + [Source((7, None), '! Some line'), Source([12, None], '! Some other line')], + Source((7, 12), '! Some line\n\n\n\n\n! Some other line') ), ( - [Source([3, 4], '! Some line\n! With line break'), Source([6, None], '! Other line\n! And new line')], - Source([3, 7], '! Some line\n! With line break\n\n! Other line\n! And new line') + [Source((3, 4), '! Some line\n! With line break'), Source([6, None], '! Other line\n! And new line')], + Source((3, 7), '! Some line\n! With line break\n\n! Other line\n! And new line') ) )) def test_join_source_list(source_list, expected):