Skip to content

Commit

Permalink
Add test for combined duplicate/remove
Browse files Browse the repository at this point in the history
  • Loading branch information
reuterbal authored and MichaelSt98 committed Jan 10, 2025
1 parent 1498e0d commit 5990887
Showing 1 changed file with 99 additions and 0 deletions.
99 changes: 99 additions & 0 deletions loki/transformations/tests/test_dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from loki.frontend import available_frontends
from loki.ir import nodes as ir, FindNodes
from loki.tools import as_tuple
from loki.transformations.dependency import (
DuplicateKernel, RemoveKernel
)
Expand Down Expand Up @@ -373,3 +374,101 @@ def test_dependency_remove_trafo_no_module(tmp_path, frontend, config):
assert "#kernel" not in scheduler

assert not FindNodes(ir.CallStatement).visit(driver.body)


@pytest.mark.usefixtures('fcode_as_module')
@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('duplicate_kernels,remove_kernels', (
('kernel', 'kernel'), ('kernel', 'kernel_new'), ('kernel', None), (None, 'kernel')
))
@pytest.mark.parametrize('full_parse', (True, False))
def test_dependency_duplicate_remove_plan(tmp_path, frontend, duplicate_kernels, remove_kernels,
config, full_parse):

scheduler = Scheduler(
paths=[tmp_path], config=SchedulerConfig.from_dict(config),
frontend=frontend, xmods=[tmp_path], full_parse=full_parse
)

expected_items = {'kernel_mod#kernel', '#driver'}
assert {item.name for item in scheduler.items} == expected_items

pipeline = Pipeline(classes=(DuplicateKernel, RemoveKernel, FileWriteTransformation),
duplicate_kernels=duplicate_kernels, duplicate_suffix='_new',
remove_kernels=remove_kernels)

plan_file = tmp_path/'plan.cmake'
scheduler.process(pipeline, proc_strategy=ProcessingStrategy.PLAN)
scheduler.write_cmake_plan(filepath=plan_file, rootpath=tmp_path)

for kernel in as_tuple(duplicate_kernels):
for name in list(expected_items):
scope_name, local_name = name.split('#')
if local_name == kernel:
expected_items.add(f'{scope_name}_new#{local_name}_new')

for kernel in as_tuple(remove_kernels):
for name in list(expected_items):
scope_name, local_name = name.split('#')
if local_name == kernel:
expected_items.remove(name)

# Validate Scheduler graph
assert {item.name for item in scheduler.items} == expected_items

# Validate the plan file content
plan_pattern = re.compile(r'set\(\s*(\w+)\s*(.*?)\s*\)', re.DOTALL)
loki_plan = plan_file.read_text()
plan_dict = {k: v.split() for k, v in plan_pattern.findall(loki_plan)}
plan_dict = {k: {Path(s).stem for s in v} for k, v in plan_dict.items()}

transformed_items = {name.split('#')[0] or name[1:] for name in expected_items if not name.endswith('_new')}
assert plan_dict['LOKI_SOURCES_TO_TRANSFORM'] == transformed_items
assert plan_dict['LOKI_SOURCES_TO_REMOVE'] == transformed_items
assert plan_dict['LOKI_SOURCES_TO_APPEND'] == {f'{name.split("#")[0] or name[1:]}.idem' for name in expected_items}


@pytest.mark.usefixtures('fcode_no_module')
@pytest.mark.parametrize('frontend', available_frontends())
@pytest.mark.parametrize('duplicate_kernels,remove_kernels', (
('kernel', 'kernel'), ('kernel', 'kernel_new'), ('kernel', None), (None, 'kernel')
))
@pytest.mark.parametrize('full_parse', (True, False))
def test_dependency_duplicate_remove_plan_no_module(tmp_path, frontend, duplicate_kernels, remove_kernels,
config, full_parse):

scheduler = Scheduler(
paths=[tmp_path], config=SchedulerConfig.from_dict(config),
frontend=frontend, xmods=[tmp_path], full_parse=full_parse
)

expected_items = {'#kernel', '#driver'}
assert {item.name for item in scheduler.items} == expected_items

pipeline = Pipeline(classes=(DuplicateKernel, RemoveKernel, FileWriteTransformation),
duplicate_kernels=duplicate_kernels, duplicate_suffix='_new',
remove_kernels=remove_kernels)

plan_file = tmp_path/'plan.cmake'
scheduler.process(pipeline, proc_strategy=ProcessingStrategy.PLAN)
scheduler.write_cmake_plan(filepath=plan_file, rootpath=tmp_path)

if duplicate_kernels:
expected_items.add(f'#{duplicate_kernels}_new')

if remove_kernels:
expected_items.remove(f'#{remove_kernels}')

# Validate Scheduler graph
assert {item.name for item in scheduler.items} == expected_items

# Validate the plan file content
plan_pattern = re.compile(r'set\(\s*(\w+)\s*(.*?)\s*\)', re.DOTALL)
loki_plan = plan_file.read_text()
plan_dict = {k: v.split() for k, v in plan_pattern.findall(loki_plan)}
plan_dict = {k: {Path(s).stem for s in v} for k, v in plan_dict.items()}

transformed_items = {name[1:] for name in expected_items if not name.endswith('_new')}
assert plan_dict['LOKI_SOURCES_TO_TRANSFORM'] == transformed_items
assert plan_dict['LOKI_SOURCES_TO_REMOVE'] == transformed_items
assert plan_dict['LOKI_SOURCES_TO_APPEND'] == {f'{name[1:]}.idem' for name in expected_items}

0 comments on commit 5990887

Please sign in to comment.