From f6b6af4b4f96fd4f230b80cd6d610cf2fe78a3a9 Mon Sep 17 00:00:00 2001 From: Pol Febrer Date: Mon, 1 Jan 2024 20:01:13 +0100 Subject: [PATCH 1/2] Added parsing of conditional expressions in workflows. Things like "a = function_1() if test() else function_2()" are now valid as workflow code. This removes the need for having to call/define a function that just does the branching. Furthermore, since the conditional expression is now understood and parsed into a custom "ConditionalExpressionSyntaxNode", only the branch needed (as decided by `test()`) is executed. Comparison syntax (e.g. "a == b") is also parsed into a custom node, therefore it is now allowed in the syntax of workflows. --- src/sisl/nodes/node.py | 33 +++++-- src/sisl/nodes/syntax_nodes.py | 95 +++++++++++++++++++++ src/sisl/nodes/tests/test_syntax_nodes.py | 51 ++++++++++- src/sisl/nodes/workflow.py | 66 +++++++++++++- src/sisl/viz/plots/bands.py | 13 ++- src/sisl/viz/plots/geometry.py | 9 +- src/sisl/viz/plots/pdos.py | 6 +- src/sisl/viz/processors/logic.py | 14 +-- src/sisl/viz/processors/tests/test_logic.py | 21 +---- 9 files changed, 251 insertions(+), 57 deletions(-) diff --git a/src/sisl/nodes/node.py b/src/sisl/nodes/node.py index d3c3a031ea..98887926bd 100644 --- a/src/sisl/nodes/node.py +++ b/src/sisl/nodes/node.py @@ -362,9 +362,34 @@ def _sanitize_inputs( def evaluate_input_node(node: Node): return node.get() - def get(self): + def _get_evaluated_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """Evaluates all inputs. + + This function is ONLY called by the get method. + + The default implementation just goes over the inputs and, if they are nodes, + makes them compute their output. But some nodes might need more complex things, + like only evaluating some inputs depending on the value of other inputs. + + Parameters + ---------- + inputs : Dict[str, Any] + The input dictionary, possibly containing nodes to evaluate. + """ # Map all inputs to their values. That is, if they are nodes, call the get # method on them so that we get the updated output. This recursively evaluates nodes. + return self.map_inputs( + inputs=inputs, + func=self.evaluate_input_node, + only_nodes=True, + ) + + def get(self): + """Returns the output of the node, possibly running the computation. + + The computation of the node is only performed if the output is outdated, + otherwise this function just returns the stored output. + """ self._logger.setLevel(getattr(logging, self.context["log_level"].upper())) logs = logging.StreamHandler(StringIO()) @@ -375,11 +400,7 @@ def get(self): self._logger.debug("Getting output from node...") self._logger.debug(f"Raw inputs: {self._inputs}") - evaluated_inputs = self.map_inputs( - inputs=self._inputs, - func=self.evaluate_input_node, - only_nodes=True, - ) + evaluated_inputs = self._get_evaluated_inputs(self._inputs) self._logger.debug(f"Evaluated inputs: {evaluated_inputs}") diff --git a/src/sisl/nodes/syntax_nodes.py b/src/sisl/nodes/syntax_nodes.py index 7e5e347e69..9885e2c123 100644 --- a/src/sisl/nodes/syntax_nodes.py +++ b/src/sisl/nodes/syntax_nodes.py @@ -1,3 +1,6 @@ +import operator +from typing import Any, Dict + from .node import Node @@ -21,3 +24,95 @@ class DictSyntaxNode(SyntaxNode): @staticmethod def function(**items): return items + + +class ConditionalExpressionSyntaxNode(SyntaxNode): + _outdate_due_to_inputs: bool = False + + def _get_evaluated_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + """Evaluate the inputs of this node. + + This function overwrites the default implementation in Node, because + we want to evaluate only the path that we are going to take. + + Parameters + ---------- + inputs : dict + The inputs to this node. + """ + + evaluated = {} + + # Get the state of the test input, which determines the path that we are going to take. + evaluated["test"] = ( + self.evaluate_input_node(inputs["test"]) + if isinstance(inputs["test"], Node) + else inputs["test"] + ) + + # Evaluate only the path that we are going to take. + if evaluated["test"]: + evaluated["true"] = ( + self.evaluate_input_node(inputs["true"]) + if isinstance(inputs["true"], Node) + else inputs["true"] + ) + evaluated["false"] = self._prev_evaluated_inputs.get("false") + else: + evaluated["false"] = ( + self.evaluate_input_node(inputs["false"]) + if isinstance(inputs["false"], Node) + else inputs["false"] + ) + evaluated["true"] = self._prev_evaluated_inputs.get("true") + + return evaluated + + def update_inputs(self, **inputs): + # This is just a wrapper over the normal update_inputs, which makes + # sure that the node is only marked as outdated if the input that + # is being used has changed. Note that here we just create a flag, + # which is then used in _receive_outdated. (_receive_outdated is + # called by super().update_inputs()) + current_test = self._prev_evaluated_inputs["test"] + + self._outdate_due_to_inputs = len(inputs) > 0 + if "test" not in inputs: + if current_test and ("true" not in inputs): + self._outdate_due_to_inputs = False + elif not current_test and ("false" not in inputs): + self._outdate_due_to_inputs = False + + try: + super().update_inputs(**inputs) + except: + self._outdate_due_to_inputs = False + raise + + def _receive_outdated(self): + # Relevant inputs have been updated, mark this node as outdated. + if self._outdate_due_to_inputs: + return super()._receive_outdated() + + # We avoid marking this node as outdated if the outdated input + # is not the one being returned. + for k in self._input_nodes: + if self._input_nodes[k]._outdated: + if k == "test": + return super()._receive_outdated() + elif k == "true": + if self._prev_evaluated_inputs["test"]: + return super()._receive_outdated() + elif k == "false": + if not self._prev_evaluated_inputs["test"]: + return super()._receive_outdated() + + @staticmethod + def function(test, true, false): + return true if test else false + + +class CompareSyntaxNode(SyntaxNode): + @staticmethod + def function(left, op: str, right): + return getattr(operator, op)(left, right) diff --git a/src/sisl/nodes/tests/test_syntax_nodes.py b/src/sisl/nodes/tests/test_syntax_nodes.py index 6267491f8f..7b8e1c11f3 100644 --- a/src/sisl/nodes/tests/test_syntax_nodes.py +++ b/src/sisl/nodes/tests/test_syntax_nodes.py @@ -1,4 +1,11 @@ -from sisl.nodes.syntax_nodes import DictSyntaxNode, ListSyntaxNode, TupleSyntaxNode +from sisl.nodes.node import ConstantNode +from sisl.nodes.syntax_nodes import ( + CompareSyntaxNode, + ConditionalExpressionSyntaxNode, + DictSyntaxNode, + ListSyntaxNode, + TupleSyntaxNode, +) from sisl.nodes.workflow import Workflow @@ -14,6 +21,38 @@ def test_dict_syntax_node(): assert DictSyntaxNode(a="b", c="d", e="f").get() == {"a": "b", "c": "d", "e": "f"} +def test_cond_expr_node(): + node = ConditionalExpressionSyntaxNode(test=True, true=1, false=2) + + assert node.get() == 1 + node.update_inputs(test=False) + + assert node._outdated + assert node.get() == 2 + + node.update_inputs(true=3) + assert not node._outdated + + # Check that only one path is evaluated. + input1 = ConstantNode(1) + input2 = ConstantNode(2) + + node = ConditionalExpressionSyntaxNode(test=True, true=input1, false=input2) + + assert node.get() == 1 + assert input1._nupdates == 1 + assert input2._nupdates == 0 + + +def test_compare_syntax_node(): + assert CompareSyntaxNode(1, "eq", 2).get() == False + assert CompareSyntaxNode(1, "ne", 2).get() == True + assert CompareSyntaxNode(1, "gt", 2).get() == False + assert CompareSyntaxNode(1, "lt", 2).get() == True + assert CompareSyntaxNode(1, "ge", 2).get() == False + assert CompareSyntaxNode(1, "le", 2).get() == True + + def test_workflow_with_syntax(): def f(a): return [a] @@ -29,3 +68,13 @@ def f(a): return {"a": a} assert Workflow.from_func(f)(2).get() == {"a": 2} + + def f(a, b, c): + return b if a else c + + assert Workflow.from_func(f)(False, 1, 2).get() == 2 + + def f(a, b): + return a != b + + assert Workflow.from_func(f)(1, 2).get() == True diff --git a/src/sisl/nodes/workflow.py b/src/sisl/nodes/workflow.py index 7780276e27..67e36bf17e 100644 --- a/src/sisl/nodes/workflow.py +++ b/src/sisl/nodes/workflow.py @@ -14,7 +14,13 @@ from .context import temporal_context from .node import DummyInputValue, Node -from .syntax_nodes import DictSyntaxNode, ListSyntaxNode, TupleSyntaxNode +from .syntax_nodes import ( + CompareSyntaxNode, + ConditionalExpressionSyntaxNode, + DictSyntaxNode, + ListSyntaxNode, + TupleSyntaxNode, +) from .utils import traverse_tree_backward, traverse_tree_forward register_environ_variable( @@ -996,6 +1002,15 @@ def _set_output(self, value): class NodeConverter(ast.NodeTransformer): """AST transformer that converts a function into a workflow.""" + ast_to_operator = { + ast.Eq: "eq", + ast.NotEq: "ne", + ast.Lt: "lt", + ast.LtE: "le", + ast.Gt: "gt", + ast.GtE: "ge", + } + def __init__( self, *args, @@ -1101,6 +1116,53 @@ def visit_Dict(self, node: ast.Dict) -> Any: return new_node + def visit_IfExp(self, node: ast.IfExp) -> Any: + """Converts the if expression syntax into a call to the ConditionalExpressionSyntaxNode.""" + new_node = ast.Call( + func=ast.Name(id="ConditionalExpressionSyntaxNode", ctx=ast.Load()), + args=[ + self.visit(node.test), + self.visit(node.body), + self.visit(node.orelse), + ], + keywords=[], + ) + + ast.fix_missing_locations(new_node) + + return new_node + + def visit_Compare(self, node: ast.Compare) -> Any: + """Converts the comparison syntax into CompareSyntaxNode call.""" + if len(node.ops) > 1: + return self.generic_visit(node) + + op = node.ops[0] + if op.__class__ not in self.ast_to_operator: + return self.generic_visit(node) + + new_node = ast.Call( + func=ast.Name(id="CompareSyntaxNode", ctx=ast.Load()), + args=[ + self.visit(node.left), + ast.Constant(value=self.ast_to_operator[op.__class__], ctx=ast.Load()), + self.visit(node.comparators[0]), + ], + keywords=[], + ) + + ast.fix_missing_locations(new_node) + + # new_node = ast.Call( + # func=ast.Name(id=self.ast_to_operator[op.__class__], ctx=ast.Load()), + # args=[self.visit(node.left), self.visit(node.comparators[0])], + # keywords=[], + # ) + + # ast.fix_missing_locations(new_node) + + return new_node + def nodify_func( func: FunctionType, @@ -1183,6 +1245,8 @@ def nodify_func( "ListSyntaxNode": ListSyntaxNode, "TupleSyntaxNode": TupleSyntaxNode, "DictSyntaxNode": DictSyntaxNode, + "ConditionalExpressionSyntaxNode": ConditionalExpressionSyntaxNode, + "CompareSyntaxNode": CompareSyntaxNode, **func_namespace, } if assign_fn_key is not None: diff --git a/src/sisl/viz/plots/bands.py b/src/sisl/viz/plots/bands.py index a4b97ea47c..229f299916 100644 --- a/src/sisl/viz/plots/bands.py +++ b/src/sisl/viz/plots/bands.py @@ -12,7 +12,6 @@ from ..plotutils import random_color from ..processors.bands import calculate_gap, draw_gaps, filter_bands, style_bands from ..processors.data import accept_data -from ..processors.logic import matches from ..processors.orbital import get_orbital_queries_manager, reduce_orbital_data from ..processors.xarray import scale_variable from .orbital_groups_plot import OrbitalGroupsPlot @@ -95,8 +94,8 @@ def bands_plot( ) # Determine what goes on each axis - x = matches(E_axis, "x", ret_true="E", ret_false="k") - y = matches(E_axis, "y", ret_true="E", ret_false="k") + x = "E" if E_axis == "x" else "k" + y = "E" if E_axis == "y" else "k" # Get the actions to plot lines bands_plottings = draw_xarray_xy( @@ -267,12 +266,10 @@ def fatbands_plot( ) # Determine what goes on each axis - x = matches(E_axis, "x", ret_true="E", ret_false="k") - y = matches(E_axis, "y", ret_true="E", ret_false="k") + x = "E" if E_axis == "x" else "k" + y = "E" if E_axis == "y" else "k" - sanitized_fatbands_mode = matches( - groups, [], ret_true="none", ret_false=fatbands_mode - ) + sanitized_fatbands_mode = "none" if groups == [] else fatbands_mode # Get the actions to plot lines fatbands_plottings = draw_xarray_xy( diff --git a/src/sisl/viz/plots/geometry.py b/src/sisl/viz/plots/geometry.py index 5ebc90760b..a8bc8234b8 100644 --- a/src/sisl/viz/plots/geometry.py +++ b/src/sisl/viz/plots/geometry.py @@ -28,7 +28,6 @@ style_bonds, tile_data_sc, ) -from ..processors.logic import matches, switch from ..processors.xarray import scale_variable, select @@ -180,7 +179,7 @@ def geometry_plot( # thread/process, potentially increasing speed. parsed_atom_style = parse_atoms_style(geometry, atoms_style=atoms_style) atoms_dataset = add_xyz_to_dataset(parsed_atom_style) - atoms_filter = switch(show_atoms, sanitized_atoms, []) + atoms_filter = sanitized_atoms if show_atoms else [] filtered_atoms = select(atoms_dataset, "atom", atoms_filter) tiled_atoms = tile_data_sc(filtered_atoms, nsc=nsc) sc_atoms = stack_sc_data(tiled_atoms, newname="sc_atom", dims=["atom"]) @@ -205,7 +204,7 @@ def geometry_plot( # Here we start to process bonds bonds = find_all_bonds(geometry) - show_bonds = matches(ndim, 1, False, show_bonds) + show_bonds = show_bonds if ndim > 1 else False styled_bonds = style_bonds(bonds, bonds_style) bonds_dataset = add_xyz_to_bonds_dataset(styled_bonds) bonds_filter = sanitize_bonds_selection( @@ -230,7 +229,7 @@ def geometry_plot( ) # And now the cell - show_cell = matches(ndim, 1, False, show_cell) + show_cell = show_cell if ndim > 1 else False cell_plottings = cell_plot_actions( cell=geometry, show_cell=show_cell, @@ -376,7 +375,7 @@ def sites_plot( ) # And now the cell - show_cell = matches(ndim, 1, False, show_cell) + show_cell = show_cell if ndim > 1 else show_cell cell_plottings = cell_plot_actions( cell=fake_geometry, show_cell=show_cell, diff --git a/src/sisl/viz/plots/pdos.py b/src/sisl/viz/plots/pdos.py index f9636b9c60..cc705ab28b 100644 --- a/src/sisl/viz/plots/pdos.py +++ b/src/sisl/viz/plots/pdos.py @@ -11,7 +11,7 @@ from ..plot import Plot from ..plotters.xarray import draw_xarray_xy from ..processors.data import accept_data -from ..processors.logic import matches, swap +from ..processors.logic import swap from ..processors.orbital import get_orbital_queries_manager, reduce_orbital_data from ..processors.xarray import filter_energy_range, scale_variable from .orbital_groups_plot import OrbitalGroupsPlot @@ -65,8 +65,8 @@ def pdos_plot( ) # Determine what goes on each axis - x = matches(E_axis, "x", ret_true="E", ret_false="PDOS") - y = matches(E_axis, "y", ret_true="E", ret_false="PDOS") + x = "E" if E_axis == "x" else "PDOS" + y = "E" if E_axis == "y" else "PDOS" dependent_axis = swap(E_axis, ("x", "y")) diff --git a/src/sisl/viz/processors/logic.py b/src/sisl/viz/processors/logic.py index 5c3dcfd08e..2dd4dd4c12 100644 --- a/src/sisl/viz/processors/logic.py +++ b/src/sisl/viz/processors/logic.py @@ -1,4 +1,4 @@ -from typing import Any, Tuple, TypeVar, Union +from typing import Tuple, TypeVar, Union T1 = TypeVar("T1") T2 = TypeVar("T2") @@ -12,15 +12,3 @@ def swap(val: Union[T1, T2], vals: Tuple[T1, T2]) -> Union[T1, T2]: return vals[0] else: raise ValueError(f"Value {val} not in {vals}") - - -def matches( - first: Any, second: Any, ret_true: T1 = True, ret_false: T2 = False -) -> Union[T1, T2]: - """If first matches second, return ret_true, else return ret_false.""" - return ret_true if first == second else ret_false - - -def switch(obj: Any, ret_true: T1, ret_false: T2) -> Union[T1, T2]: - """If obj is True, return ret_true, else return ret_false.""" - return ret_true if obj else ret_false diff --git a/src/sisl/viz/processors/tests/test_logic.py b/src/sisl/viz/processors/tests/test_logic.py index b09117b0dd..956d957af1 100644 --- a/src/sisl/viz/processors/tests/test_logic.py +++ b/src/sisl/viz/processors/tests/test_logic.py @@ -1,6 +1,6 @@ import pytest -from sisl.viz.processors.logic import matches, swap, switch +from sisl.viz.processors.logic import swap pytestmark = [pytest.mark.viz, pytest.mark.processors] @@ -11,22 +11,3 @@ def test_swap(): with pytest.raises(ValueError): swap(3, (1, 2)) - - -def test_matches(): - assert matches(1, 1) == True - assert matches(1, 2) == False - - assert matches(1, 1, "a", "b") == "a" - assert matches(1, 2, "a", "b") == "b" - - assert matches(1, 1, "a") == "a" - assert matches(1, 2, "a") == False - - assert matches(1, 1, ret_false="b") == True - assert matches(1, 2, ret_false="b") == "b" - - -def test_switch(): - assert switch(True, "a", "b") == "a" - assert switch(False, "a", "b") == "b" From 117ca49658cc615d254df255413c40656f3d2d0e Mon Sep 17 00:00:00 2001 From: Pol Febrer Date: Mon, 1 Jan 2024 20:23:49 +0100 Subject: [PATCH 2/2] Better display of conditionals on workflow diagrams --- src/sisl/nodes/node.py | 4 ++++ src/sisl/nodes/syntax_nodes.py | 18 ++++++++++++++++++ src/sisl/nodes/workflow.py | 4 +++- 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/src/sisl/nodes/node.py b/src/sisl/nodes/node.py index 98887926bd..894eecbbfa 100644 --- a/src/sisl/nodes/node.py +++ b/src/sisl/nodes/node.py @@ -677,6 +677,10 @@ def _maybe_autoupdate(self): if not self.context["lazy"]: self.get() + def get_diagram_label(self): + """Returns the label to be used in diagrams when displaying this node.""" + return None + class DummyInputValue(Node): """A dummy node that can be used as a placeholder for input values.""" diff --git a/src/sisl/nodes/syntax_nodes.py b/src/sisl/nodes/syntax_nodes.py index 9885e2c123..b94df1f256 100644 --- a/src/sisl/nodes/syntax_nodes.py +++ b/src/sisl/nodes/syntax_nodes.py @@ -111,8 +111,26 @@ def _receive_outdated(self): def function(test, true, false): return true if test else false + def get_diagram_label(self): + """Returns the label to be used in diagrams when displaying this node.""" + return "if/else" + class CompareSyntaxNode(SyntaxNode): + _op_to_symbol = { + "eq": "==", + "ne": "!=", + "gt": ">", + "lt": "<", + "ge": ">=", + "le": "<=", + None: "compare", + } + @staticmethod def function(left, op: str, right): return getattr(operator, op)(left, right) + + def get_diagram_label(self): + """Returns the label to be used in diagrams when displaying this node.""" + return self._op_to_symbol.get(self._prev_evaluated_inputs.get("op")) diff --git a/src/sisl/nodes/workflow.py b/src/sisl/nodes/workflow.py index 67e36bf17e..c96ba1f357 100644 --- a/src/sisl/nodes/workflow.py +++ b/src/sisl/nodes/workflow.py @@ -343,8 +343,9 @@ def rgb2gray(rgb): for node in nodes: graph_node = graph.nodes[node] + node_obj = self._workflow.dryrun_nodes.get(node) + if node_help: - node_obj = self._workflow.dryrun_nodes.get(node) title = ( _get_node_inputs_str(node_obj) if node_obj is not None else "" ) @@ -359,6 +360,7 @@ def rgb2gray(rgb): "level": level, "title": title, "font": font, + "label": node_obj.get_diagram_label(), **node_props, } )