Skip to content

Commit

Permalink
Merge pull request #666 from pfebrer/conditional_expressions
Browse files Browse the repository at this point in the history
Support for conditional expressions on workflows.
  • Loading branch information
zerothi authored Jan 4, 2024
2 parents c667b7e + 117ca49 commit 25c1b3a
Show file tree
Hide file tree
Showing 9 changed files with 276 additions and 58 deletions.
37 changes: 31 additions & 6 deletions src/sisl/nodes/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,9 +362,34 @@ def _sanitize_inputs(
def evaluate_input_node(node: Node):
return node.get()

def get(self):
def _get_evaluated_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Evaluates all inputs.
This function is ONLY called by the get method.
The default implementation just goes over the inputs and, if they are nodes,
makes them compute their output. But some nodes might need more complex things,
like only evaluating some inputs depending on the value of other inputs.
Parameters
----------
inputs : Dict[str, Any]
The input dictionary, possibly containing nodes to evaluate.
"""
# Map all inputs to their values. That is, if they are nodes, call the get
# method on them so that we get the updated output. This recursively evaluates nodes.
return self.map_inputs(
inputs=inputs,
func=self.evaluate_input_node,
only_nodes=True,
)

def get(self):
"""Returns the output of the node, possibly running the computation.
The computation of the node is only performed if the output is outdated,
otherwise this function just returns the stored output.
"""
self._logger.setLevel(getattr(logging, self.context["log_level"].upper()))

logs = logging.StreamHandler(StringIO())
Expand All @@ -375,11 +400,7 @@ def get(self):
self._logger.debug("Getting output from node...")
self._logger.debug(f"Raw inputs: {self._inputs}")

evaluated_inputs = self.map_inputs(
inputs=self._inputs,
func=self.evaluate_input_node,
only_nodes=True,
)
evaluated_inputs = self._get_evaluated_inputs(self._inputs)

self._logger.debug(f"Evaluated inputs: {evaluated_inputs}")

Expand Down Expand Up @@ -656,6 +677,10 @@ def _maybe_autoupdate(self):
if not self.context["lazy"]:
self.get()

def get_diagram_label(self):
"""Returns the label to be used in diagrams when displaying this node."""
return None


class DummyInputValue(Node):
"""A dummy node that can be used as a placeholder for input values."""
Expand Down
113 changes: 113 additions & 0 deletions src/sisl/nodes/syntax_nodes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import operator
from typing import Any, Dict

from .node import Node


Expand All @@ -21,3 +24,113 @@ class DictSyntaxNode(SyntaxNode):
@staticmethod
def function(**items):
return items


class ConditionalExpressionSyntaxNode(SyntaxNode):
_outdate_due_to_inputs: bool = False

def _get_evaluated_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
"""Evaluate the inputs of this node.
This function overwrites the default implementation in Node, because
we want to evaluate only the path that we are going to take.
Parameters
----------
inputs : dict
The inputs to this node.
"""

evaluated = {}

# Get the state of the test input, which determines the path that we are going to take.
evaluated["test"] = (
self.evaluate_input_node(inputs["test"])
if isinstance(inputs["test"], Node)
else inputs["test"]
)

# Evaluate only the path that we are going to take.
if evaluated["test"]:
evaluated["true"] = (
self.evaluate_input_node(inputs["true"])
if isinstance(inputs["true"], Node)
else inputs["true"]
)
evaluated["false"] = self._prev_evaluated_inputs.get("false")
else:
evaluated["false"] = (
self.evaluate_input_node(inputs["false"])
if isinstance(inputs["false"], Node)
else inputs["false"]
)
evaluated["true"] = self._prev_evaluated_inputs.get("true")

return evaluated

def update_inputs(self, **inputs):
# This is just a wrapper over the normal update_inputs, which makes
# sure that the node is only marked as outdated if the input that
# is being used has changed. Note that here we just create a flag,
# which is then used in _receive_outdated. (_receive_outdated is
# called by super().update_inputs())
current_test = self._prev_evaluated_inputs["test"]

self._outdate_due_to_inputs = len(inputs) > 0
if "test" not in inputs:
if current_test and ("true" not in inputs):
self._outdate_due_to_inputs = False
elif not current_test and ("false" not in inputs):
self._outdate_due_to_inputs = False

try:
super().update_inputs(**inputs)
except:
self._outdate_due_to_inputs = False
raise

def _receive_outdated(self):
# Relevant inputs have been updated, mark this node as outdated.
if self._outdate_due_to_inputs:
return super()._receive_outdated()

# We avoid marking this node as outdated if the outdated input
# is not the one being returned.
for k in self._input_nodes:
if self._input_nodes[k]._outdated:
if k == "test":
return super()._receive_outdated()
elif k == "true":
if self._prev_evaluated_inputs["test"]:
return super()._receive_outdated()
elif k == "false":
if not self._prev_evaluated_inputs["test"]:
return super()._receive_outdated()

@staticmethod
def function(test, true, false):
return true if test else false

def get_diagram_label(self):
"""Returns the label to be used in diagrams when displaying this node."""
return "if/else"


class CompareSyntaxNode(SyntaxNode):
_op_to_symbol = {
"eq": "==",
"ne": "!=",
"gt": ">",
"lt": "<",
"ge": ">=",
"le": "<=",
None: "compare",
}

@staticmethod
def function(left, op: str, right):
return getattr(operator, op)(left, right)

def get_diagram_label(self):
"""Returns the label to be used in diagrams when displaying this node."""
return self._op_to_symbol.get(self._prev_evaluated_inputs.get("op"))
51 changes: 50 additions & 1 deletion src/sisl/nodes/tests/test_syntax_nodes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
from sisl.nodes.syntax_nodes import DictSyntaxNode, ListSyntaxNode, TupleSyntaxNode
from sisl.nodes.node import ConstantNode
from sisl.nodes.syntax_nodes import (
CompareSyntaxNode,
ConditionalExpressionSyntaxNode,
DictSyntaxNode,
ListSyntaxNode,
TupleSyntaxNode,
)
from sisl.nodes.workflow import Workflow


Expand All @@ -14,6 +21,38 @@ def test_dict_syntax_node():
assert DictSyntaxNode(a="b", c="d", e="f").get() == {"a": "b", "c": "d", "e": "f"}


def test_cond_expr_node():
node = ConditionalExpressionSyntaxNode(test=True, true=1, false=2)

assert node.get() == 1
node.update_inputs(test=False)

assert node._outdated
assert node.get() == 2

node.update_inputs(true=3)
assert not node._outdated

# Check that only one path is evaluated.
input1 = ConstantNode(1)
input2 = ConstantNode(2)

node = ConditionalExpressionSyntaxNode(test=True, true=input1, false=input2)

assert node.get() == 1
assert input1._nupdates == 1
assert input2._nupdates == 0


def test_compare_syntax_node():
assert CompareSyntaxNode(1, "eq", 2).get() == False
assert CompareSyntaxNode(1, "ne", 2).get() == True
assert CompareSyntaxNode(1, "gt", 2).get() == False
assert CompareSyntaxNode(1, "lt", 2).get() == True
assert CompareSyntaxNode(1, "ge", 2).get() == False
assert CompareSyntaxNode(1, "le", 2).get() == True


def test_workflow_with_syntax():
def f(a):
return [a]
Expand All @@ -29,3 +68,13 @@ def f(a):
return {"a": a}

assert Workflow.from_func(f)(2).get() == {"a": 2}

def f(a, b, c):
return b if a else c

assert Workflow.from_func(f)(False, 1, 2).get() == 2

def f(a, b):
return a != b

assert Workflow.from_func(f)(1, 2).get() == True
70 changes: 68 additions & 2 deletions src/sisl/nodes/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@

from .context import temporal_context
from .node import DummyInputValue, Node
from .syntax_nodes import DictSyntaxNode, ListSyntaxNode, TupleSyntaxNode
from .syntax_nodes import (
CompareSyntaxNode,
ConditionalExpressionSyntaxNode,
DictSyntaxNode,
ListSyntaxNode,
TupleSyntaxNode,
)
from .utils import traverse_tree_backward, traverse_tree_forward

register_environ_variable(
Expand Down Expand Up @@ -337,8 +343,9 @@ def rgb2gray(rgb):
for node in nodes:
graph_node = graph.nodes[node]

node_obj = self._workflow.dryrun_nodes.get(node)

if node_help:
node_obj = self._workflow.dryrun_nodes.get(node)
title = (
_get_node_inputs_str(node_obj) if node_obj is not None else ""
)
Expand All @@ -353,6 +360,7 @@ def rgb2gray(rgb):
"level": level,
"title": title,
"font": font,
"label": node_obj.get_diagram_label(),
**node_props,
}
)
Expand Down Expand Up @@ -996,6 +1004,15 @@ def _set_output(self, value):
class NodeConverter(ast.NodeTransformer):
"""AST transformer that converts a function into a workflow."""

ast_to_operator = {
ast.Eq: "eq",
ast.NotEq: "ne",
ast.Lt: "lt",
ast.LtE: "le",
ast.Gt: "gt",
ast.GtE: "ge",
}

def __init__(
self,
*args,
Expand Down Expand Up @@ -1101,6 +1118,53 @@ def visit_Dict(self, node: ast.Dict) -> Any:

return new_node

def visit_IfExp(self, node: ast.IfExp) -> Any:
"""Converts the if expression syntax into a call to the ConditionalExpressionSyntaxNode."""
new_node = ast.Call(
func=ast.Name(id="ConditionalExpressionSyntaxNode", ctx=ast.Load()),
args=[
self.visit(node.test),
self.visit(node.body),
self.visit(node.orelse),
],
keywords=[],
)

ast.fix_missing_locations(new_node)

return new_node

def visit_Compare(self, node: ast.Compare) -> Any:
"""Converts the comparison syntax into CompareSyntaxNode call."""
if len(node.ops) > 1:
return self.generic_visit(node)

op = node.ops[0]
if op.__class__ not in self.ast_to_operator:
return self.generic_visit(node)

new_node = ast.Call(
func=ast.Name(id="CompareSyntaxNode", ctx=ast.Load()),
args=[
self.visit(node.left),
ast.Constant(value=self.ast_to_operator[op.__class__], ctx=ast.Load()),
self.visit(node.comparators[0]),
],
keywords=[],
)

ast.fix_missing_locations(new_node)

# new_node = ast.Call(
# func=ast.Name(id=self.ast_to_operator[op.__class__], ctx=ast.Load()),
# args=[self.visit(node.left), self.visit(node.comparators[0])],
# keywords=[],
# )

# ast.fix_missing_locations(new_node)

return new_node


def nodify_func(
func: FunctionType,
Expand Down Expand Up @@ -1183,6 +1247,8 @@ def nodify_func(
"ListSyntaxNode": ListSyntaxNode,
"TupleSyntaxNode": TupleSyntaxNode,
"DictSyntaxNode": DictSyntaxNode,
"ConditionalExpressionSyntaxNode": ConditionalExpressionSyntaxNode,
"CompareSyntaxNode": CompareSyntaxNode,
**func_namespace,
}
if assign_fn_key is not None:
Expand Down
13 changes: 5 additions & 8 deletions src/sisl/viz/plots/bands.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from ..plotutils import random_color
from ..processors.bands import calculate_gap, draw_gaps, filter_bands, style_bands
from ..processors.data import accept_data
from ..processors.logic import matches
from ..processors.orbital import get_orbital_queries_manager, reduce_orbital_data
from ..processors.xarray import scale_variable
from .orbital_groups_plot import OrbitalGroupsPlot
Expand Down Expand Up @@ -95,8 +94,8 @@ def bands_plot(
)

# Determine what goes on each axis
x = matches(E_axis, "x", ret_true="E", ret_false="k")
y = matches(E_axis, "y", ret_true="E", ret_false="k")
x = "E" if E_axis == "x" else "k"
y = "E" if E_axis == "y" else "k"

# Get the actions to plot lines
bands_plottings = draw_xarray_xy(
Expand Down Expand Up @@ -267,12 +266,10 @@ def fatbands_plot(
)

# Determine what goes on each axis
x = matches(E_axis, "x", ret_true="E", ret_false="k")
y = matches(E_axis, "y", ret_true="E", ret_false="k")
x = "E" if E_axis == "x" else "k"
y = "E" if E_axis == "y" else "k"

sanitized_fatbands_mode = matches(
groups, [], ret_true="none", ret_false=fatbands_mode
)
sanitized_fatbands_mode = "none" if groups == [] else fatbands_mode

# Get the actions to plot lines
fatbands_plottings = draw_xarray_xy(
Expand Down
Loading

0 comments on commit 25c1b3a

Please sign in to comment.