Skip to content

Commit

Permalink
Rename kernels argument in dependency trafos
Browse files Browse the repository at this point in the history
  • Loading branch information
reuterbal authored and MichaelSt98 committed Jan 10, 2025
1 parent 98366f4 commit 1498e0d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
16 changes: 8 additions & 8 deletions loki/transformations/dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,16 @@ class DuplicateKernel(Transformation):

reverse_traversal = True

def __init__(self, kernels=None, duplicate_suffix='duplicated',
def __init__(self, duplicate_kernels=None, duplicate_suffix='duplicated',
duplicate_module_suffix=None):
self.suffix = duplicate_suffix
self.module_suffix = duplicate_module_suffix or duplicate_suffix
self.kernels = tuple(kernel.lower() for kernel in as_tuple(kernels))
self.duplicate_kernels = tuple(kernel.lower() for kernel in as_tuple(duplicate_kernels))

def _create_duplicate_items(self, successors, item_factory, config):
new_items = ()
for item in successors:
if item.local_name in self.kernels:
if item.local_name in self.duplicate_kernels:
# Determine new item name
scope_name = item.scope_name
local_name = f'{item.local_name}{self.suffix}'
Expand Down Expand Up @@ -72,7 +72,7 @@ def transform_subroutine(self, routine, **kwargs):
new_imports = []
for call in FindNodes(ir.CallStatement).visit(routine.body):
call_name = str(call.name).lower()
if call_name in self.kernels:
if call_name in self.duplicate_kernels:
# Duplicate the call
new_call_name = f'{call_name}{self.suffix}'.lower()
new_item = new_dependencies[new_call_name]
Expand Down Expand Up @@ -102,13 +102,13 @@ class RemoveKernel(Transformation):

creates_items = True

def __init__(self, kernels=None):
self.kernels = tuple(kernel.lower() for kernel in as_tuple(kernels))
def __init__(self, remove_kernels=None):
self.remove_kernels = tuple(kernel.lower() for kernel in as_tuple(remove_kernels))

def transform_subroutine(self, routine, **kwargs):
call_map = {
call: None for call in FindNodes(ir.CallStatement).visit(routine.body)
if str(call.name).lower() in self.kernels
if str(call.name).lower() in self.remove_kernels
}
routine.body = Transformer(call_map).visit(routine.body)

Expand All @@ -118,5 +118,5 @@ def plan_subroutine(self, routine, **kwargs):
successors = as_tuple(kwargs.get('successors'))
item.plan_data.setdefault('removed_dependencies', ())
item.plan_data['removed_dependencies'] += tuple(
child for child in successors if child.local_name in self.kernels
child for child in successors if child.local_name in self.remove_kernels
)
14 changes: 7 additions & 7 deletions loki/transformations/tests/test_dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def test_dependency_duplicate_plan(tmp_path, frontend, suffix, module_suffix, co
)

pipeline = Pipeline(classes=(DuplicateKernel, FileWriteTransformation),
kernels=('kernel',), duplicate_suffix=suffix,
duplicate_kernels=('kernel',), duplicate_suffix=suffix,
duplicate_module_suffix=module_suffix)

plan_file = tmp_path/'plan.cmake'
Expand Down Expand Up @@ -189,7 +189,7 @@ def test_dependency_duplicate_trafo(tmp_path, frontend, suffix, module_suffix, c
)

pipeline = Pipeline(classes=(DuplicateKernel, FileWriteTransformation),
kernels=('kernel',), duplicate_suffix=suffix,
duplicate_kernels=('kernel',), duplicate_suffix=suffix,
duplicate_module_suffix=module_suffix)

scheduler.process(pipeline)
Expand Down Expand Up @@ -231,7 +231,7 @@ def test_dependency_remove(tmp_path, frontend, config):
frontend=frontend, xmods=[tmp_path]
)
pipeline = Pipeline(classes=(RemoveKernel, FileWriteTransformation),
kernels=('kernel',))
remove_kernels=('kernel',))

plan_file = tmp_path/'plan.cmake'
root_path = tmp_path
Expand Down Expand Up @@ -267,7 +267,7 @@ def test_dependency_duplicate_plan_no_module(tmp_path, frontend, suffix, module_
)

pipeline = Pipeline(classes=(DuplicateKernel, FileWriteTransformation),
kernels=('kernel',), duplicate_suffix=suffix,
duplicate_kernels=('kernel',), duplicate_suffix=suffix,
duplicate_module_suffix=module_suffix)

plan_file = tmp_path/'plan.cmake'
Expand Down Expand Up @@ -306,7 +306,7 @@ def test_dependency_duplicate_trafo_no_module(tmp_path, frontend, suffix, module
)

pipeline = Pipeline(classes=(DuplicateKernel, FileWriteTransformation),
kernels=('kernel',), duplicate_suffix=suffix,
duplicate_kernels=('kernel',), duplicate_suffix=suffix,
duplicate_module_suffix=module_suffix)

scheduler.process(pipeline)
Expand Down Expand Up @@ -339,7 +339,7 @@ def test_dependency_remove_plan_no_module(tmp_path, frontend, config, full_parse
frontend=frontend, xmods=[tmp_path], full_parse=full_parse
)
pipeline = Pipeline(classes=(RemoveKernel, FileWriteTransformation),
kernels=('kernel',))
remove_kernels=('kernel',))

plan_file = tmp_path/'plan.cmake'
scheduler.process(pipeline, proc_strategy=ProcessingStrategy.PLAN)
Expand All @@ -366,7 +366,7 @@ def test_dependency_remove_trafo_no_module(tmp_path, frontend, config):
frontend=frontend, xmods=[tmp_path]
)
pipeline = Pipeline(classes=(RemoveKernel, FileWriteTransformation),
kernels=('kernel',))
remove_kernels=('kernel',))

scheduler.process(pipeline)
driver = scheduler["#driver"].ir
Expand Down

0 comments on commit 1498e0d

Please sign in to comment.