Skip to content

Commit

Permalink
#2845 always inline symbols into Routine symbol table.
Browse files Browse the repository at this point in the history
  • Loading branch information
arporter committed Jan 14, 2025
1 parent 234101b commit b78c558
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 26 deletions.
50 changes: 25 additions & 25 deletions src/psyclone/psyir/transformations/inline_trans.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,15 @@ def apply(self, node, options=None):
:type options: Optional[Dict[str, Any]]
:param bool options["force"]: whether or not to permit the inlining
of Routines containing CodeBlocks. Default is False.
:raises InternalError: if the merge of the symbol tables fails.
In theory this should never happen because validate() should
catch such a situation.
'''
self.validate(node, options)
# The table associated with the scoping region holding the Call.
table = node.scope.symbol_table
table = node.ancestor(Routine).symbol_table
# Find the routine to be inlined.
orig_routine = node.get_callees()[0]

Expand All @@ -161,8 +166,14 @@ def apply(self, node, options=None):

# Shallow copy the symbols from the routine into the table at the
# call site.
table.merge(routine_table,
symbols_to_skip=routine_table.argument_list[:])
try:
table.merge(routine_table,
symbols_to_skip=routine_table.argument_list[:])
except SymbolError as err:
raise InternalError(
f"Error copying routine symbols to call site. This should "
f"have been caught by the validate() method. Original error "
f"was {err}")

# When constructing new references to replace references to formal
# args, we need to know whether any of the actual arguments are array
Expand All @@ -184,11 +195,6 @@ def apply(self, node, options=None):
for ref in refs[:]:
self._replace_formal_arg(ref, node, formal_args)

# Store the Routine level symbol table and node's current scope
# so we can merge symbol tables later if required.
ancestor_table = node.ancestor(Routine).scope.symbol_table
scope = node.scope

# Copy the nodes from the Routine into the call site.
# TODO #924 - while doing this we should ensure that any References
# to common/shared Symbols in the inlined code are updated to point
Expand Down Expand Up @@ -221,19 +227,6 @@ def apply(self, node, options=None):
idx += 1
parent.addchild(child, idx)

# If the scope we merged the inlined function's symbol table into
# is not a Routine scope then we now merge that symbol table into
# the ancestor Routine. This avoids issues like #2424 when
# applying ParallelLoopTrans to loops containing inlined calls.
if ancestor_table is not scope.symbol_table:
try:
ancestor_table.merge(scope.symbol_table)
except SymbolError as err:
raise InternalError("No escape") from err
replacement = type(scope.symbol_table)()
scope.symbol_table.detach()
replacement.attach(scope)

def _replace_formal_arg(self, ref, call_node, formal_args):
'''
Recursively combines any References to formal arguments in the supplied
Expand Down Expand Up @@ -602,6 +595,7 @@ def validate(self, node, options=None):
and the 'force' option is not True.
:raises TransformationError: if the called routine has a named
argument.
:raises TransformationError: if the call-site is not within a Routine.
:raises TransformationError: if any of the variables declared within
the called routine are of UnknownInterface.
:raises TransformationError: if any of the variables declared within
Expand Down Expand Up @@ -636,6 +630,13 @@ def validate(self, node, options=None):
f"Cannot inline an IntrinsicCall ('{node.routine.name}')")
name = node.routine.name

# The call site must be within a Routine (i.e. not detached)
parent_routine = node.ancestor(Routine)
if not parent_routine:
raise TransformationError(
f"Routine '{name}' cannot be inlined because the call site "
f"('{node.debug_string().strip()}') is not inside a Routine.")

# Check that we can find the source of the routine being inlined.
# TODO #924 allow for multiple routines (interfaces).
try:
Expand Down Expand Up @@ -664,8 +665,8 @@ def validate(self, node, options=None):
# CodeBlocks to be included.
raise TransformationError(
f"Routine '{name}' contains one or more CodeBlocks and "
"therefore cannot be inlined. (If you are confident that "
"the code may safely be inlined despite this then use "
f"therefore cannot be inlined. (If you are confident that "
f"the code may safely be inlined despite this then use "
"`options={'force': True}` to override.)")

# Support for routines with named arguments is not yet implemented.
Expand All @@ -676,8 +677,7 @@ def validate(self, node, options=None):
f"Routine '{routine.name}' cannot be inlined because it "
f"has a named argument '{arg}' (TODO #924).")

parent_routine = node.ancestor(Routine)
table = parent_routine.symbol_table # node.scope.symbol_table
table = parent_routine.symbol_table
routine_table = routine.symbol_table

for sym in routine_table.datasymbols:
Expand Down
48 changes: 47 additions & 1 deletion src/psyclone/tests/psyir/transformations/inline_trans_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import pytest

from psyclone.configuration import Config
from psyclone.errors import InternalError
from psyclone.psyir.nodes import Call, IntrinsicCall, Reference, Routine, Loop
from psyclone.psyir.symbols import (
AutomaticInterface, DataSymbol, UnresolvedType)
Expand Down Expand Up @@ -1271,6 +1272,36 @@ def test_apply_callsite_rename_container(fortran_reader, fortran_writer):
" i = i * a_mod_1\n" in output)


def test_apply_internal_error(fortran_reader, monkeypatch):
'''
Test that we raise the expected error in apply if we find a situation that
should have been caught by validate.
'''
code = (
"module test_mod\n"
"contains\n"
" subroutine run_it()\n"
" use some_mod, only: a_clash\n"
" integer :: i\n"
" i = 10\n"
" call sub(i)\n"
" end subroutine run_it\n"
" subroutine sub(idx)\n"
" use other_mod, only: a_clash\n"
" integer :: idx\n"
" idx = idx + trouble\n"
" end subroutine sub\n"
"end module test_mod\n")
psyir = fortran_reader.psyir_from_source(code)
call = psyir.walk(Call)[0]
inline_trans = InlineTrans()
monkeypatch.setattr(inline_trans, "validate", lambda _a, _b: None)
with pytest.raises(InternalError) as err:
inline_trans.apply(call)
assert ("Error copying routine symbols to call site. This should have "
"been caught" in str(err.value))


def test_validate_non_local_import(fortran_reader):
'''Test that we reject the case where the routine to be
inlined accesses a symbol from an import in its parent container.'''
Expand Down Expand Up @@ -2164,7 +2195,22 @@ def test_validate_named_arg(fortran_reader):
f"end module inline_mod\n")


def test_apply_merges_symbol_table_with_routine(fortran_reader):
def test_validate_call_within_routine(fortran_reader):
'''
Check that validate raises the expected error if the call is not within
a Routine.
'''
psyir = fortran_reader.psyir_from_source(CALL_IN_SUB_USE)
call = psyir.walk(Call)[0]
inline_trans = InlineTrans()
with pytest.raises(TransformationError) as err:
inline_trans.validate(call.detach())
assert ("Routine 'sub' cannot be inlined because the call site ('call "
"sub(a)') is not inside a Routine" in str(err.value))


def test_apply_merges_symbol_table_with_routine(fortran_reader,
fortran_writer):
'''
Check that the apply method merges the inlined function's symbol table to
the containing Routine when the call node is inside a child ScopingNode.
Expand Down

0 comments on commit b78c558

Please sign in to comment.