Skip to content

Commit

Permalink
Sanitise: Use iterator, when matching free symbols in assoc resolve
Browse files Browse the repository at this point in the history
When using dict-mapping to match symbols, the range keys might be
`:`, which alias and mean we'd miss susequent `:` matches. The test
has been updated accordingly.
  • Loading branch information
mlange05 committed Dec 11, 2024
1 parent c055d93 commit bed1ae8
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 8 deletions.
12 changes: 10 additions & 2 deletions loki/transformations/sanitise/associates.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,17 @@ def _match_range_indices(expressions, indices):
free_symbols = tuple(e for e in expressions if isinstance(e, sym.RangeIndex))
if any(s.lower not in (None, 1) for s in free_symbols):
warning('WARNING: Bounds shifts through association is currently not supported')
symbol_map = dict(zip(free_symbols, indices))

return tuple(symbol_map.get(e, e) for e in expressions)
if len(free_symbols) == len(indices):
# If the provided indices are enough to bind free symbols,
# we match them in sequence.
it = iter(indices)
return tuple(
next(it) if isinstance(e, sym.RangeIndex) else e
for e in expressions
)

return expressions

def map_scalar(self, expr, *args, **kwargs):
# Skip unscoped expressions
Expand Down
19 changes: 13 additions & 6 deletions loki/transformations/sanitise/tests/test_associates.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,18 +156,23 @@ def test_transform_associates_array_slices(frontend):
Test the resolution of associated array slices.
"""
fcode = """
subroutine transform_associates_slices(arr2d)
subroutine transform_associates_slices(arr2d, arr3d)
use some_module, only: some_obj, another_routine
implicit none
real, intent(inout) :: arr2d(:,:)
integer :: i
real, intent(inout) :: arr2d(:,:), arr3d(:,:,:)
integer :: i, j
integer, parameter :: idx_a = 2
integer, parameter :: idx_c = 3
associate (a => arr2d(:, 1), b=>arr2d(:, idx_a) )
associate (a => arr2d(:, 1), b=>arr2d(:, idx_a), &
& c => arr3d(:,:,idx_c) )
b(:) = 42.0
do i=1, 5
a(i) = b(i+2)
call another_routine(i, a(2:4), b)
do j=1, 7
c(i, j) = c(i, j) + b(j)
end do
end do
end associate
end subroutine transform_associates_slices
Expand All @@ -177,7 +182,7 @@ def test_transform_associates_array_slices(frontend):
assert len(FindNodes(ir.Associate).visit(routine.body)) == 1
assert len(FindNodes(ir.CallStatement).visit(routine.body)) == 1
assigns = FindNodes(ir.Assignment).visit(routine.body)
assert len(assigns) == 2
assert len(assigns) == 3
calls = FindNodes(ir.CallStatement).visit(routine.body)
assert len(calls) == 1
assert calls[0].arguments[1] == 'a(2:4)'
Expand All @@ -188,10 +193,12 @@ def test_transform_associates_array_slices(frontend):

assert len(FindNodes(ir.Associate).visit(routine.body)) == 0
assigns = FindNodes(ir.Assignment).visit(routine.body)
assert len(assigns) == 2
assert len(assigns) == 3
assert assigns[0].lhs == 'arr2d(:, idx_a)'
assert assigns[1].lhs == 'arr2d(i, 1)'
assert assigns[1].rhs == 'arr2d(i+2, idx_a)'
assert assigns[2].lhs == 'arr3d(i, j, idx_c)'
assert assigns[2].rhs == 'arr3d(i, j, idx_c) + arr2d(j, idx_a)'

calls = FindNodes(ir.CallStatement).visit(routine.body)
assert len(calls) == 1
Expand Down

0 comments on commit bed1ae8

Please sign in to comment.