Skip to content

Commit

Permalink
Merge pull request #229 from ecmwf-ifs/naml-frontend-optimisations
Browse files Browse the repository at this point in the history
Performance optimisations for frontend parsing/sanitising
  • Loading branch information
reuterbal authored Feb 21, 2024
2 parents f25e375 + e8ba69e commit 22abc5c
Show file tree
Hide file tree
Showing 7 changed files with 179 additions and 99 deletions.
6 changes: 3 additions & 3 deletions loki/frontend/fparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
StringConcat, ParenthesisedAdd, ParenthesisedMul, ParenthesisedDiv, ParenthesisedPow
)
from loki.expression import ExpressionDimensionsMapper, AttachScopes, AttachScopesMapper
from loki.logging import debug, info, warning, error
from loki.logging import debug, perf, info, warning, error
from loki.tools import as_tuple, flatten, CaseInsensitiveDict, LazyNodeLookup
from loki.pragma_utils import (
attach_pragmas, process_dimension_pragmas, detach_pragmas, pragmas_attached
Expand All @@ -55,7 +55,7 @@ def parse_fparser_file(filename):
return parse_fparser_source(source=fcode)


@Timer(logger=debug, text=lambda s: f'[Loki::FP] Executed parse_fparser_source in {s:.2f}s')
@Timer(logger=perf, text=lambda s: f'[Loki::FP] Executed parse_fparser_source in {s:.2f}s')
def parse_fparser_source(source):
"""
Generate a parse tree from string
Expand All @@ -77,7 +77,7 @@ def parse_fparser_source(source):
return f2008_parser(reader)


@Timer(logger=debug, text=lambda s: f'[Loki::FP] Executed parse_fparser_ast in {s:.2f}s')
@Timer(logger=perf, text=lambda s: f'[Loki::FP] Executed parse_fparser_ast in {s:.2f}s')
def parse_fparser_ast(ast, raw_source, pp_info=None, definitions=None, scope=None):
"""
Generate an internal IR from fparser parse tree
Expand Down
2 changes: 1 addition & 1 deletion loki/frontend/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,4 +519,4 @@ def join_source_list(source_list):
newlines = 0
string += '\n' * newlines + source.string
lines[1] = source.lines[1] if source.lines[1] else lines[1] + newlines + source.string.count('\n')
return Source(lines, string, source_list[0].file)
return Source(tuple(lines), string, source_list[0].file)
128 changes: 57 additions & 71 deletions loki/frontend/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,25 @@
from enum import IntEnum
from pathlib import Path
import codecs
from codetiming import Timer
from more_itertools import split_after

from loki.visitors import NestedTransformer, FindNodes, PatternFinder, SequenceFinder
from loki.visitors import (
NestedTransformer, FindNodes, PatternFinder, Transformer
)
from loki.ir import (
Assignment, Comment, CommentBlock, VariableDeclaration, ProcedureDeclaration,
Loop, Intrinsic, Pragma
)
from loki.frontend.source import Source
from loki.logging import warning
from loki.frontend.source import join_source_list
from loki.logging import warning, perf
from loki.tools import group_by_class, replace_windowed


__all__ = [
'Frontend', 'OFP', 'OMNI', 'FP', 'REGEX',
'inline_comments', 'cluster_comments', 'read_file',
'combine_multiline_pragmas', 'sanitize_ir'
'Frontend', 'OFP', 'OMNI', 'FP', 'REGEX', 'inline_comments',
'ClusterCommentTransformer', 'read_file',
'CombineMultilinePragmasTransformer', 'sanitize_ir'
]


Expand Down Expand Up @@ -65,28 +71,26 @@ def inline_comments(ir):
return NestedTransformer(mapper, invalidate_source=False).visit(ir)


def cluster_comments(ir):
class ClusterCommentTransformer(Transformer):
    """
    Combines consecutive sets of :any:`Comment` into a :any:`CommentBlock`.
    """

    def visit_tuple(self, o, **kwargs):
        """
        Find groups of :any:`Comment` and inject into the tuple.

        Parameters
        ----------
        o : tuple
            Tuple of IR nodes in which consecutive comments are clustered

        Returns
        -------
        tuple
            The transformed tuple with comment groups replaced by blocks
        """
        cgroups = group_by_class(o, Comment)
        for group in cgroups:
            # Combine the group into a single CommentBlock; the joined
            # source spans all comments in the group
            source = join_source_list(tuple(p.source for p in group))
            block = CommentBlock(comments=group, label=group[0].label, source=source)
            o = replace_windowed(o, group, subs=(block,))

        # Then recurse over the new nodes
        return tuple(self.visit(i, **kwargs) for i in o)

    # Lists are handled identically to tuples
    visit_list = visit_tuple


def inline_labels(ir):
Expand Down Expand Up @@ -135,51 +139,33 @@ def read_file(file_path):
return source


def combine_multiline_pragmas(ir):
class CombineMultilinePragmasTransformer(Transformer):
    """
    Combine multiline :any:`Pragma` nodes into single ones.
    """

    def visit_tuple(self, o, **kwargs):
        """
        Finds multi-line pragmas and combines them in-place.

        Parameters
        ----------
        o : tuple
            Tuple of IR nodes in which multi-line pragmas are combined

        Returns
        -------
        tuple
            The transformed tuple with continuation pragmas merged
        """
        pgroups = group_by_class(o, Pragma)

        for group in pgroups:
            # A trailing '&' marks a line continuation, so split the group
            # after every pragma that does NOT end with '&'
            pred = lambda p: not p.content.rstrip().endswith('&')  # pylint: disable=unnecessary-lambda-assignment
            for pragmaset in split_after(group, pred=pred):
                # Combine the continuation set into a single pragma
                source = join_source_list(tuple(p.source for p in pragmaset))
                content = ' '.join(p.content.rstrip(' &') for p in pragmaset)
                new_pragma = Pragma(
                    keyword=pragmaset[0].keyword, content=content, source=source
                )
                o = replace_windowed(o, pragmaset, subs=(new_pragma,))

        return tuple(self.visit(i, **kwargs) for i in o)


@Timer(logger=perf, text=lambda s: f'[Loki::Frontend] Executed sanitize_ir in {s:.2f}s')
def sanitize_ir(_ir, frontend, pp_registry=None, pp_info=None):
"""
Utility function to sanitize internal representation after creating it
Expand All @@ -189,8 +175,8 @@ def sanitize_ir(_ir, frontend, pp_registry=None, pp_info=None):
the following operations:
* :any:`inline_comments` to attach inline-comments to IR nodes
* :any:`cluster_comments` to combine multi-line comments into :any:`CommentBlock`
* :any:`combine_multiline_pragmas` to combine multi-line pragmas into a
* :any:`ClusterCommentTransformer` to combine multi-line comments into :any:`CommentBlock`
* :any:`CombineMultilinePragmasTransformer` to combine multi-line pragmas into a
single node
Parameters
Expand All @@ -213,12 +199,12 @@ def sanitize_ir(_ir, frontend, pp_registry=None, pp_info=None):

# Perform some minor sanitation tasks
_ir = inline_comments(_ir)
_ir = cluster_comments(_ir)
_ir = ClusterCommentTransformer(invalidate_source=False).visit(_ir)

if frontend in (OMNI, OFP):
_ir = inline_labels(_ir)

if frontend in (FP, OFP):
_ir = combine_multiline_pragmas(_ir)
_ir = CombineMultilinePragmasTransformer(invalidate_source=False).visit(_ir)

return _ir
43 changes: 42 additions & 1 deletion loki/tools/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from subprocess import run, PIPE, STDOUT, CalledProcessError
from contextlib import contextmanager
from pathlib import Path
from itertools import groupby
from more_itertools import replace

try:
import yaml
Expand All @@ -31,7 +33,7 @@
'execute', 'CaseInsensitiveDict', 'strip_inline_comments',
'binary_insertion_sort', 'cached_func', 'optional', 'LazyNodeLookup',
'yaml_include_constructor', 'auto_post_mortem_debugger', 'set_excepthook',
'timeout', 'WeakrefProperty'
'timeout', 'WeakrefProperty', 'group_by_class', 'replace_windowed'
]


Expand Down Expand Up @@ -628,3 +630,42 @@ def __set__(self, obj, value):
obj.__dict__[self._name] = value
else:
setattr(obj, self._name, value)


def group_by_class(iterable, klass):
    """
    Find groups of consecutive instances of the same type with more
    than one element.

    Parameters
    ----------
    iterable : iterable
        Input iterable from which to extract groups
    klass : type
        Type by which to group elements in the given iterable

    Returns
    -------
    tuple of tuple
        All consecutive runs of ``klass`` instances with length > 1
    """
    runs = []
    for key, members in groupby(iterable, key=lambda x: x.__class__):
        if key != klass:
            continue
        run = tuple(members)
        # Only runs of at least two consecutive elements are of interest
        if len(run) > 1:
            runs.append(run)
    return tuple(runs)


def replace_windowed(iterable, group, subs):
    """
    Replace a set of consecutive elements in a larger iterable.

    Parameters
    ----------
    iterable : iterable
        Input iterable in which to replace elements
    group : iterable
        Group of elements to replace in ``iterable``
    subs : any
        Replacement for ``group`` in ``iterable``

    Returns
    -------
    tuple
        The iterable with occurrences of ``group`` replaced by ``subs``
    """
    group = as_tuple(group)
    window = len(group)
    substitutes = as_tuple(subs)

    def _matches(*window_elems):
        # A window matches when it equals the group element-for-element
        return window_elems == group

    replaced = replace(
        iterable, pred=_matches, substitutes=substitutes, window_size=window
    )
    return tuple(replaced)
9 changes: 2 additions & 7 deletions loki/visitors/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@
"""
Visitor classes for transforming the IR
"""
from more_itertools import windowed

from loki.ir import Node, Conditional, ScopedNode
from loki.tools import flatten, is_iterable, as_tuple
from loki.tools import flatten, is_iterable, as_tuple, replace_windowed
from loki.visitors.visitor import Visitor


Expand Down Expand Up @@ -137,10 +135,7 @@ def _inject_handle(nodes, i, old, new):

for k, handle in self.mapper.items():
if is_iterable(k):
w = list(windowed(o, len(k)))
if k in w:
i = list(w).index(k)
o = o[:i] + as_tuple(handle) + o[i+len(k):]
o = replace_windowed(o, k, subs=handle)
if k in o and is_iterable(handle):
# Replace k by the iterable that is provided by handle
o, i = _inject_handle(o, 0, k, handle)
Expand Down
62 changes: 60 additions & 2 deletions tests/test_frontends.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
Deallocation, Associate, BasicType, OMNI, OFP, FP, Enumeration,
config, REGEX, Sourcefile, Import, RawSource, CallStatement,
RegexParserClass, ProcedureType, DerivedType, Comment, Pragma,
PreprocessorDirective, config_override, Section
PreprocessorDirective, config_override, Section, CommentBlock
)
from loki.expression import symbols as sym

Expand Down Expand Up @@ -1686,15 +1686,73 @@ def test_pragma_line_continuation(frontend):
!$ACC& PRESENT(ZRDG_LCVQ,ZFLU_QSATS,ZRDG_CVGQ) &
!$ACC& PRIVATE (JBLK) &
!$ACC& VECTOR_LENGTH (YDCPG_OPTS%KLON)
!$ACC SEQUENTIAL
END SUBROUTINE TOTO
"""
routine = Subroutine.from_source(fcode, frontend=frontend)

pragmas = FindNodes(Pragma).visit(routine.body)
assert len(pragmas) == 1

assert len(pragmas) == 2
assert pragmas[0].keyword == 'ACC'
assert 'PARALLEL' in pragmas[0].content
assert 'PRESENT' in pragmas[0].content
assert 'PRIVATE' in pragmas[0].content
assert 'VECTOR_LENGTH' in pragmas[0].content
assert pragmas[1].content == 'SEQUENTIAL'

# Check that source object was generated right
assert pragmas[0].source
assert pragmas[0].source.lines == (8, 8) if frontend == OMNI else (8, 11)
assert pragmas[1].source
assert pragmas[1].source.lines == (12, 12)


@pytest.mark.parametrize('frontend', available_frontends())
def test_comment_block_clustering(frontend):
    """
    Test that multiple :any:`Comment` nodes are combined into a :any:`CommentBlock`.
    """
    fcode = """
subroutine test_comment_block(a, b)
! What is this?
! Ohhh, ... a docstring?
real, intent(inout) :: a, b
a = a + 1.0
! Never gonna
b = b + 2
! give you
! up...
a = a + b
! Shut up, ...
! Rick!
end subroutine test_comment_block
"""
    routine = Subroutine.from_source(fcode, frontend=frontend)

    # Spec-level comments are absorbed into the docstring block
    comments = FindNodes(Comment).visit(routine.spec)
    assert len(comments) == 0
    blocks = FindNodes(CommentBlock).visit(routine.spec)
    assert len(blocks) == 0

    assert isinstance(routine.docstring[0], CommentBlock)
    assert len(routine.docstring[0].comments) == 2
    assert routine.docstring[0].comments[0].text == '! What is this?'
    assert routine.docstring[0].comments[1].text == '! Ohhh, ... a docstring?'

    comments = FindNodes(Comment).visit(routine.body)
    # NOTE: the conditional must be parenthesised. Without parentheses
    # `assert len(comments) == 2 if frontend == FP else 1` parses as
    # `assert (len(comments) == 2) if frontend == FP else 1`, which for
    # non-FP frontends degenerates to `assert 1` and checks nothing.
    assert len(comments) == (2 if frontend == FP else 1)
    assert comments[-1].text == '! Never gonna'

    blocks = FindNodes(CommentBlock).visit(routine.body)
    assert len(blocks) == 2
    # Same precedence fix as above: parenthesise the conditional value
    assert len(blocks[0].comments) == (3 if frontend == FP else 2)
    assert blocks[0].comments[0].text == '! give you'
    assert blocks[0].comments[1].text == '! up...'

    assert len(blocks[1].comments) == 2
    assert blocks[1].comments[0].text == '! Shut up, ...'
    assert blocks[1].comments[1].text == '! Rick!'
Loading

0 comments on commit 22abc5c

Please sign in to comment.