From bed1ae8faf1cff9f4b00fcd5af266782160e2446 Mon Sep 17 00:00:00 2001 From: Michael Lange Date: Wed, 11 Dec 2024 10:52:57 +0000 Subject: [PATCH] Sanitise: Use iterator, when matching free symbols in assoc resolve When using dict-mapping to match symbols, the range keys might be `:`, which alias and mean we'd miss susequent `:` matches. The test has been updated accordingly. --- loki/transformations/sanitise/associates.py | 12 ++++++++++-- .../sanitise/tests/test_associates.py | 19 +++++++++++++------ 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/loki/transformations/sanitise/associates.py b/loki/transformations/sanitise/associates.py index 5ccfc0ecb..9e50c9e1b 100644 --- a/loki/transformations/sanitise/associates.py +++ b/loki/transformations/sanitise/associates.py @@ -121,9 +121,17 @@ def _match_range_indices(expressions, indices): free_symbols = tuple(e for e in expressions if isinstance(e, sym.RangeIndex)) if any(s.lower not in (None, 1) for s in free_symbols): warning('WARNING: Bounds shifts through association is currently not supported') - symbol_map = dict(zip(free_symbols, indices)) - return tuple(symbol_map.get(e, e) for e in expressions) + if len(free_symbols) == len(indices): + # If the provided indices are enough to bind free symbols, + # we match them in sequence. + it = iter(indices) + return tuple( + next(it) if isinstance(e, sym.RangeIndex) else e + for e in expressions + ) + + return expressions def map_scalar(self, expr, *args, **kwargs): # Skip unscoped expressions diff --git a/loki/transformations/sanitise/tests/test_associates.py b/loki/transformations/sanitise/tests/test_associates.py index 085c7016c..33fb22cbc 100644 --- a/loki/transformations/sanitise/tests/test_associates.py +++ b/loki/transformations/sanitise/tests/test_associates.py @@ -156,18 +156,23 @@ def test_transform_associates_array_slices(frontend): Test the resolution of associated array slices. """ fcode = """ -subroutine transform_associates_slices(arr2d) +subroutine transform_associates_slices(arr2d, arr3d) use some_module, only: some_obj, another_routine implicit none - real, intent(inout) :: arr2d(:,:) - integer :: i + real, intent(inout) :: arr2d(:,:), arr3d(:,:,:) + integer :: i, j integer, parameter :: idx_a = 2 + integer, parameter :: idx_c = 3 - associate (a => arr2d(:, 1), b=>arr2d(:, idx_a) ) + associate (a => arr2d(:, 1), b=>arr2d(:, idx_a), & + & c => arr3d(:,:,idx_c) ) b(:) = 42.0 do i=1, 5 a(i) = b(i+2) call another_routine(i, a(2:4), b) + do j=1, 7 + c(i, j) = c(i, j) + b(j) + end do end do end associate end subroutine transform_associates_slices @@ -177,7 +182,7 @@ def test_transform_associates_array_slices(frontend): assert len(FindNodes(ir.Associate).visit(routine.body)) == 1 assert len(FindNodes(ir.CallStatement).visit(routine.body)) == 1 assigns = FindNodes(ir.Assignment).visit(routine.body) - assert len(assigns) == 2 + assert len(assigns) == 3 calls = FindNodes(ir.CallStatement).visit(routine.body) assert len(calls) == 1 assert calls[0].arguments[1] == 'a(2:4)' @@ -188,10 +193,12 @@ def test_transform_associates_array_slices(frontend): assert len(FindNodes(ir.Associate).visit(routine.body)) == 0 assigns = FindNodes(ir.Assignment).visit(routine.body) - assert len(assigns) == 2 + assert len(assigns) == 3 assert assigns[0].lhs == 'arr2d(:, idx_a)' assert assigns[1].lhs == 'arr2d(i, 1)' assert assigns[1].rhs == 'arr2d(i+2, idx_a)' + assert assigns[2].lhs == 'arr3d(i, j, idx_c)' + assert assigns[2].rhs == 'arr3d(i, j, idx_c) + arr2d(j, idx_a)' calls = FindNodes(ir.CallStatement).visit(routine.body) assert len(calls) == 1