Skip to content

Commit

Permalink
Transform: Add simple Dead Code Elimination to trim code paths
Browse files Browse the repository at this point in the history
  • Loading branch information
mlange05 committed Dec 12, 2023
1 parent 61f41ba commit 72d8eb6
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 5 deletions.
1 change: 1 addition & 0 deletions loki/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@
from loki.transform.transform_hoist_variables import * # noqa
from loki.transform.transform_parametrise import * # noqa
from loki.transform.transform_sequence_association import * # noqa
from loki.transform.transform_dead_code import * # noqa
66 changes: 66 additions & 0 deletions loki/transform/transform_dead_code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

"""
Collection of utilities to perform Dead Code Elimination.
"""
from loki.visitors import Transformer
from loki.expression.symbolic import simplify
from loki.tools import flatten, as_tuple


__all__ = ['dead_code_elimination', 'DeadCodeEliminationTransformer']


def dead_code_elimination(routine, use_simplify=True):
"""
Perform Dead Code Elimination on the given :any:`Subroutine` object.
Parameters
----------
routine : :any:`Subroutine`
The subroutine to which to apply dead code elimination.
simplify : boolean
Use :any:`simplify` when evaluating expressions for branch pruning.
"""

transformer = DeadCodeEliminationTransformer(use_simplify=use_simplify)
routine.body = transformer.visit(routine.body)


class DeadCodeEliminationTransformer(Transformer):
"""
:any:`Transformer` class that removes provably unreachable code paths.
The pirmary modification performed is to prune individual code branches
under :any:`Conditional` nodes.
Parameters
----------
simplify : boolean
Use :any:`simplify` when evaluating expressions for branch pruning.
"""

def __init__(self, use_simplify=True, **kwargs):
super().__init__(**kwargs)
self.use_simplify = use_simplify

def visit_Conditional(self, o, **kwargs):
condition = self.visit(o.condition, **kwargs)
body = as_tuple(flatten(as_tuple(self.visit(o.body, **kwargs))))
else_body = as_tuple(flatten(as_tuple(self.visit(o.else_body, **kwargs))))

if self.use_simplify:
condition = simplify(condition)

if condition == 'True':
return body

if condition == 'False':
return else_body

return self._rebuild(o, tuple((condition,) + (body,) + (else_body,)))
9 changes: 6 additions & 3 deletions scripts/loki_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,13 @@ def cli(debug):
help='Replace array arguments passed as scalars with arrays.')
@click.option('--derive-argument-array-shape/--no-derive-argument-array-shape', default=False,
help="Recursively derive explicit shape dimension for argument arrays")
@click.option('--dead-code-elimination/--no-dead-code-elimination', default=True,
help='Perform dead code elimination, where unreachable branches are trimmed from the code.')
def convert(
mode, config, build, source, header, cpp, directive, include, define, omni_include, xmod,
data_offload, remove_openmp, assume_deviceptr, frontend, trim_vector_sections,
global_var_offload, remove_derived_args, inline_members, inline_marked,
resolve_sequence_association, derive_argument_array_shape
resolve_sequence_association, derive_argument_array_shape, dead_code_elimination
):
"""
Batch-processing mode for Fortran-to-Fortran transformations that
Expand Down Expand Up @@ -212,8 +214,9 @@ def convert(
if mode in ['scc', 'scc-hoist', 'scc-stack']:
# Apply the basic SCC transformation set
scheduler.process( SCCBaseTransformation(
horizontal=horizontal, directive=directive,
inline_members=inline_members, resolve_sequence_association=resolve_sequence_association
horizontal=horizontal, directive=directive, inline_members=inline_members,
resolve_sequence_association=resolve_sequence_association,
dead_code_elimination=dead_code_elimination
))
scheduler.process( SCCDevectorTransformation(
horizontal=horizontal, trim_vector_sections=trim_vector_sections
Expand Down
52 changes: 52 additions & 0 deletions tests/test_transform_dead_code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# (C) Copyright 2018- ECMWF.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import pytest

from conftest import available_frontends
from loki import Subroutine, FindNodes, Conditional, Assignment
from loki.transform import dead_code_elimination


@pytest.mark.parametrize('frontend', available_frontends())
def test_transform_dead_code_conditional(frontend):
"""
Test correct elimination of unreachable conditional branches.
"""
fcode = """
subroutine test_dead_code_conditional(a, b)
real(kind=8), intent(inout) :: a, b
logical, intent(in) :: flag
if (flag) then
if (1 == 6) then
a = a + b
else
b = b + 2.0
end if
if (2 == 2) then
b = b + a
else
a = a + 3.0
end if
end if
end subroutine test_dead_code_conditional
"""
routine = Subroutine.from_source(fcode, frontend=frontend)
assert len(FindNodes(Conditional).visit(routine.body)) == 3
assert len(FindNodes(Assignment).visit(routine.body)) == 4

dead_code_elimination(routine)

conditionals = FindNodes(Conditional).visit(routine.body)
assert len(conditionals) == 1
assert conditionals[0].condition == 'flag'
assigns = FindNodes(Assignment).visit(routine.body)
assert len(assigns) == 2
assert assigns[0].lhs == 'b' and assigns[0].rhs == 'b + 2.0'
assert assigns[1].lhs == 'b' and assigns[1].rhs == 'b + a'
12 changes: 10 additions & 2 deletions transformations/transformations/single_column_coalesced.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from loki.expression import symbols as sym
from loki.transform import (
resolve_associates, inline_member_procedures,
inline_marked_subroutines, transform_sequence_association
inline_marked_subroutines, transform_sequence_association,
dead_code_elimination
)
from loki import (
Transformation, FindNodes, Transformer, info,
Expand Down Expand Up @@ -45,11 +46,14 @@ class methods can be called directly.
Enable inlining for subroutines marked with ``!$loki inline``; default: True.
resolve_sequence_association : bool
Replace scalars that are passed to array arguments with array ranges; default: False.
dead_code_elimination : bool
Perform dead code elimination, where unreachable branches are trimmed from the code.
"""

def __init__(
self, horizontal, directive=None, inline_members=False,
inline_marked=True, resolve_sequence_association=False
inline_marked=True, resolve_sequence_association=False,
dead_code_elimination=True
):
self.horizontal = horizontal

Expand All @@ -59,6 +63,7 @@ def __init__(
self.inline_members = inline_members
self.inline_marked = inline_marked
self.resolve_sequence_association = resolve_sequence_association
self.dead_code_elimination = dead_code_elimination

@classmethod
def check_routine_pragmas(cls, routine, directive):
Expand Down Expand Up @@ -321,6 +326,9 @@ def process_kernel(self, routine):
# with the sections we need to do for detecting subroutine calls
resolve_associates(routine)

if self.dead_code_elimination:
dead_code_elimination(routine)

# Resolve WHERE clauses
self.resolve_masked_stmts(routine, loop_variable=v_index)

Expand Down

0 comments on commit 72d8eb6

Please sign in to comment.