diff --git a/loki/visitors/ir_graph.py b/loki/visitors/ir_graph.py index e645d3907..8b6fd6ed9 100644 --- a/loki/visitors/ir_graph.py +++ b/loki/visitors/ir_graph.py @@ -184,9 +184,33 @@ def __add_node(self, node, **kwargs): A list of a tuple of a node and potentially a edge information """ label = kwargs.get("label", "") + if label == "": label = self.format_node(repr(node)) + try: + live_symbols = "live: [" + ", ".join( + str(symbol) for symbol in node.live_symbols + ) + defines_symbols = "defines: [" + ", ".join( + str(symbol) for symbol in node.defines_symbols + ) + uses_symbols = "uses: [" + ", ".join( + str(symbol) for symbol in node.uses_symbols + ) + label = self.format_line( + label, + "\n", + live_symbols, + "], ", + defines_symbols, + "], ", + uses_symbols, + "]", + ) + except (RuntimeError, KeyError, AttributeError) as _: + pass + shape = kwargs.get("shape", "oval") node_key = str(id(node)) @@ -321,8 +345,7 @@ def visit_Conditional(self, o, **kwargs): return node_edge_info -def ir_graph( - ir, show_comments=False, show_expressions=False, linewidth=40, symgen=str): +def ir_graph(ir, show_comments=False, show_expressions=False, linewidth=40, symgen=str): """ Pretty-print the given IR using :class:`GraphCollector`. @@ -342,8 +365,12 @@ def ir_graph( log = "[Loki::Graph Visualization] Created graph visualization in {:.2f}s" with Timer(text=log): - graph_representation = GraphCollector(show_comments, show_expressions, linewidth, symgen) - node_edge_info = [item for item in graph_representation.visit(ir) if item is not None] + graph_representation = GraphCollector( + show_comments, show_expressions, linewidth, symgen + ) + node_edge_info = [ + item for item in graph_representation.visit(ir) if item is not None + ] graph = Digraph() graph.attr(rankdir="LR") diff --git a/tests/test_ir_graph.py b/tests/test_ir_graph.py index 7958d3d88..ac7e14c1d 100644 --- a/tests/test_ir_graph.py +++ b/tests/test_ir_graph.py @@ -11,6 +11,9 @@ from conftest import graphviz_present from loki import Sourcefile from loki.visitors.ir_graph import ir_graph, GraphCollector +from loki.visitors import FindNodes +from loki.analyse import dataflow_analysis_attached +from loki.ir import Node @pytest.fixture(scope="module", name="here") @@ -197,7 +200,9 @@ def test_graph_collector_node_edge_count_only( graph_collector = GraphCollector( show_comments=show_comments, show_expressions=show_expressions ) - node_edge_info = [item for item in graph_collector.visit(source.ir) if item is not None] + node_edge_info = [ + item for item in graph_collector.visit(source.ir) if item is not None + ] node_names = [name for (name, _) in get_property(node_edge_info, "name")] node_labels = [label for (label, _) in get_property(node_edge_info, "label")] @@ -224,7 +229,9 @@ def test_graph_collector_detail(here, test_file): source = Sourcefile.from_file(here / test_file) graph_collector = GraphCollector() - node_edge_info = [item for item in graph_collector.visit(source.ir) if item is not None] + node_edge_info = [ + item for item in graph_collector.visit(source.ir) if item is not None + ] node_names = [name for (name, _) in get_property(node_edge_info, "name")] node_labels = [label for (label, _) in get_property(node_edge_info, "label")] @@ -252,7 +259,9 @@ def test_graph_collector_maximum_label_length(here, test_file, linewidth): graph_collector = GraphCollector( show_comments=True, show_expressions=True, linewidth=linewidth ) - node_edge_info = [item for item in graph_collector.visit(source.ir) if item is not None] + node_edge_info = [ + item for item in graph_collector.visit(source.ir) if item is not None + ] node_labels = [label for (label, _) in get_property(node_edge_info, "label")] for label in node_labels: @@ -309,3 +318,38 @@ def test_ir_graph_writes_correct_graphs(here, test_file): for node, label in zip(node_ids, found_labels): assert solution["node_labels"][node[0]] == label[0] + + +@pytest.mark.parametrize("test_file", test_files) +def test_ir_graph_dataflow_analysis_attached(here, test_file): + source = Sourcefile.from_file(here / test_file) + + def find_lives_defines_uses(text): + # Regular expression pattern to match content within square brackets after 'live:', 'defines:', and 'uses:' + pattern = r"live:\s*\[([^\]]*?)\],\s*defines:\s*\[([^\]]*?)\],\s*uses:\s*\[([^\]]*?)\]" + matches = re.search(pattern, text) + assert matches + + def remove_spaces_and_newlines(text): + return text.replace(" ", "").replace("\n", "") + + def disregard_empty_strings(elements): + return set(element for element in elements if element != "") + + def apply_conversion(text): + return disregard_empty_strings(remove_spaces_and_newlines(text).split(",")) + + return ( + apply_conversion(matches.group(1)), + apply_conversion(matches.group(2)), + apply_conversion(matches.group(3)), + ) + + for routine in source.all_subroutines: + with dataflow_analysis_attached(routine): + for node in FindNodes(Node).visit(routine.body): + node_info, _ = GraphCollector(show_comments=True).visit(node)[0] + lives, defines, uses = find_lives_defines_uses(node_info["label"]) + assert node.live_symbols == set(lives) + assert node.uses_symbols == set(uses) + assert node.defines_symbols == set(defines)