diff --git a/aiida_workgraph/engine/launch.py b/aiida_workgraph/engine/launch.py new file mode 100644 index 00000000..dfd3052c --- /dev/null +++ b/aiida_workgraph/engine/launch.py @@ -0,0 +1,217 @@ +from __future__ import annotations + +import time +import typing as t + +from aiida.common import InvalidOperation +from aiida.common.log import AIIDA_LOGGER +from aiida.manage import manager +from aiida.orm import ProcessNode + +from aiida.engine.processes.builder import ProcessBuilder +from aiida.engine.processes.functions import get_stack_size +from aiida.engine.processes.process import Process +from aiida.engine.utils import prepare_inputs, is_process_function + +import signal +import sys +import inspect +from typing import ( + Type, + Union, +) + + +from aiida.manage import get_manager + +__all__ = ("run_get_node", "submit") + +TYPE_RUN_PROCESS = t.Union[Process, t.Type[Process], ProcessBuilder] +# run can also be process function, but it is not clear what type this should be +TYPE_SUBMIT_PROCESS = t.Union[Process, t.Type[Process], ProcessBuilder] +LOGGER = AIIDA_LOGGER.getChild("engine.launch") + + +def run_get_node( + process_class, *args, **kwargs +) -> tuple[dict[str, t.Any] | None, "ProcessNode"]: + """Run the FunctionProcess with the supplied inputs in a local runner. + :param args: input arguments to construct the FunctionProcess + :param kwargs: input keyword arguments to construct the FunctionProcess + :return: tuple of the outputs of the process and the process node + """ + parent_pid = kwargs.pop("_parent_pid", None) + frame_delta = 1000 + frame_count = get_stack_size() + stack_limit = sys.getrecursionlimit() + LOGGER.info( + "Executing process function, current stack status: %d frames of %d", + frame_count, + stack_limit, + ) + # If the current frame count is more than 80% of the stack limit, or comes within 200 frames, increase the + # stack limit by ``frame_delta``. + if frame_count > min(0.8 * stack_limit, stack_limit - 200): + LOGGER.warning( + "Current stack contains %d frames which is close to the limit of %d. Increasing the limit by %d", + frame_count, + stack_limit, + frame_delta, + ) + sys.setrecursionlimit(stack_limit + frame_delta) + manager = get_manager() + runner = manager.get_runner() + inputs = process_class.create_inputs(*args, **kwargs) + # Remove all the known inputs from the kwargs + for port in process_class.spec().inputs: + kwargs.pop(port, None) + # If any kwargs remain, the spec should be dynamic, so we raise if it isn't + if kwargs and not process_class.spec().inputs.dynamic: + raise ValueError( + f"{function.__name__} does not support these kwargs: {kwargs.keys()}" + ) + process = process_class(inputs=inputs, runner=runner, parent_pid=parent_pid) + # Only add handlers for interrupt signal to kill the process if we are in a local and not a daemon runner. + # Without this check, running process functions in a daemon worker would be killed if the daemon is shutdown + current_runner = manager.get_runner() + original_handler = None + kill_signal = signal.SIGINT + if not current_runner.is_daemon_runner: + + def kill_process(_num, _frame): + """Send the kill signal to the process in the current scope.""" + LOGGER.critical( + "runner received interrupt, killing process %s", process.pid + ) + result = process.kill( + msg="Process was killed because the runner received an interrupt" + ) + return result + + # Store the current handler on the signal such that it can be restored after process has terminated + original_handler = signal.getsignal(kill_signal) + signal.signal(kill_signal, kill_process) + try: + result = process.execute() + finally: + # If the `original_handler` is set, that means the `kill_process` was bound, which needs to be reset + if original_handler: + signal.signal(signal.SIGINT, original_handler) + store_provenance = inputs.get("metadata", {}).get("store_provenance", True) + if not store_provenance: + process.node._storable = False + process.node._unstorable_message = ( + "cannot store node because it was run with `store_provenance=False`" + ) + return result, process.node + + +def instantiate_process( + runner: "Runner", + process: Union["Process", Type["Process"], "ProcessBuilder"], + _parent_pid, + **inputs, +) -> "Process": + """Return an instance of the process with the given inputs. The function can deal with various types + of the `process`: + + * Process instance: will simply return the instance + * ProcessBuilder instance: will instantiate the Process from the class and inputs defined within it + * Process class: will instantiate with the specified inputs + + If anything else is passed, a ValueError will be raised + + :param process: Process instance or class, CalcJobNode class or ProcessBuilder instance + :param inputs: the inputs for the process to be instantiated with + """ + + if isinstance(process, Process): + assert not inputs + assert runner is process.runner + return process + + if isinstance(process, ProcessBuilder): + builder = process + process_class = builder.process_class + inputs.update(**builder._inputs(prune=True)) + elif is_process_function(process): + process_class = process.process_class # type: ignore[attr-defined] + elif inspect.isclass(process) and issubclass(process, Process): + process_class = process + else: + raise ValueError( + f"invalid process {type(process)}, needs to be Process or ProcessBuilder" + ) + + process = process_class(runner=runner, inputs=inputs, parent_pid=_parent_pid) + + return process + + +def submit( + process: TYPE_SUBMIT_PROCESS, + inputs: dict[str, t.Any] | None = None, + *, + wait: bool = False, + wait_interval: int = 5, + **kwargs: t.Any, +) -> ProcessNode: + """Submit the process with the supplied inputs to the daemon immediately returning control to the interpreter. + + .. warning: this should not be used within another process. Instead, there one should use the ``submit`` method of + the wrapping process itself, i.e. use ``self.submit``. + + .. warning: submission of processes requires ``store_provenance=True``. + + :param process: the process class, instance or builder to submit + :param inputs: the input dictionary to be passed to the process + :param wait: when set to ``True``, the submission will be blocking and wait for the process to complete at which + point the function returns the calculation node. + :param wait_interval: the number of seconds to wait between checking the state of the process when ``wait=True``. + :param kwargs: inputs to be passed to the process. This is an alternative to the positional ``inputs`` argument. + :return: the calculation node of the process + """ + _parent_pid = kwargs.pop("_parent_pid", None) + inputs = prepare_inputs(inputs, **kwargs) + + # Submitting from within another process requires ``self.submit``` unless it is a work function, in which case the + # current process in the scope should be an instance of ``FunctionProcess``. + # if is_process_scoped() and not isinstance(Process.current(), FunctionProcess): + # raise InvalidOperation('Cannot use top-level `submit` from within another process, use `self.submit` instead') + + runner = manager.get_manager().get_runner() + assert runner.persister is not None, "runner does not have a persister" + assert runner.controller is not None, "runner does not have a controller" + + process_inited = instantiate_process( + runner, process, _parent_pid=_parent_pid, **inputs + ) + + # If a dry run is requested, simply forward to `run`, because it is not compatible with `submit`. We choose for this + # instead of raising, because in this way the user does not have to change the launcher when testing. The same goes + # for if `remote_folder` is present in the inputs, which means we are importing an already completed calculation. + if process_inited.metadata.get("dry_run", False) or "remote_folder" in inputs: + _, node = run_get_node(process_inited) + return node + + if not process_inited.metadata.store_provenance: + raise InvalidOperation("cannot submit a process with `store_provenance=False`") + + runner.persister.save_checkpoint(process_inited) + process_inited.close() + + # Do not wait for the future's result, because in the case of a single worker this would cock-block itself + runner.controller.continue_process(process_inited.pid, nowait=False, no_reply=True) + node = process_inited.node + + if not wait: + return node + + while not node.is_terminated: + LOGGER.report( + f"Process<{node.pk}> has not yet terminated, current state is `{node.process_state}`. " + f"Waiting for {wait_interval} seconds." + ) + time.sleep(wait_interval) + + return node diff --git a/aiida_workgraph/engine/scheduler.py b/aiida_workgraph/engine/scheduler.py index 04201860..ba7aa798 100644 --- a/aiida_workgraph/engine/scheduler.py +++ b/aiida_workgraph/engine/scheduler.py @@ -9,7 +9,7 @@ from plumpy import process_comms from plumpy.persistence import auto_persist -from plumpy.process_states import Continue, Wait +from plumpy.process_states import Continue, Wait, Finished, Running import kiwipy from aiida.common import exceptions @@ -29,7 +29,6 @@ construct_awaitable, ) from aiida.engine.processes.workchains.workchain import Protect, WorkChainSpec -from aiida.engine import run_get_node from aiida_workgraph.utils import create_and_pause_process from aiida_workgraph.task import Task from aiida_workgraph.utils import get_nested_dict, update_nested_dict @@ -216,6 +215,8 @@ def _insert_awaitable(self, awaitable: Awaitable) -> None: else: raise AssertionError(f"Unsupported awaitable action: {awaitable.action}") + # Register the callback to be called when the awaitable is resolved + self._add_callback_to_awaitable(awaitable) self._awaitables.append( awaitable ) # add only if everything went ok, otherwise we end up in an inconsistent state @@ -269,6 +270,7 @@ def to_context(self, **kwargs: Awaitable | ProcessNode) -> None: for key, value in kwargs.items(): awaitable = construct_awaitable(value) awaitable.key = key + awaitable.workgraph_pk = value.workgraph_pk self._insert_awaitable(awaitable) def _update_process_status(self) -> None: @@ -351,17 +353,22 @@ def _action_awaitables(self) -> None: # if the waitable already has a callback, skip if awaitable.pk in self.ctx._workgraph[pk]["_awaitable_actions"]: continue - if awaitable.target == AwaitableTarget.PROCESS: - callback = functools.partial( - self.call_soon, self._on_awaitable_finished, awaitable - ) - self.runner.call_on_process_finish(awaitable.pk, callback) - self.ctx._workgraph[pk]["_awaitable_actions"].append(awaitable.pk) - elif awaitable.target == "asyncio.tasks.Task": - # this is a awaitable task, the callback function is already set - self.ctx._workgraph[pk]["_awaitable_actions"].append(awaitable.pk) - else: - assert f"invalid awaitable target '{awaitable.target}'" + self._add_callback_to_awaitable(awaitable) + + def _add_callback_to_awaitable(self, awaitable: Awaitable) -> None: + """Add a callback to the awaitable.""" + pk = awaitable.workgraph_pk + if awaitable.target == AwaitableTarget.PROCESS: + callback = functools.partial( + self.call_soon, self._on_awaitable_finished, awaitable + ) + self.runner.call_on_process_finish(awaitable.pk, callback) + self.ctx._workgraph[pk]["_awaitable_actions"].append(awaitable.pk) + elif awaitable.target == "asyncio.tasks.Task": + # this is a awaitable task, the callback function is already set + self.ctx._workgraph[pk]["_awaitable_actions"].append(awaitable.pk) + else: + assert f"invalid awaitable target '{awaitable.target}'" def _on_awaitable_finished(self, awaitable: Awaitable) -> None: """Callback function, for when an awaitable process instance is completed. @@ -371,8 +378,12 @@ def _on_awaitable_finished(self, awaitable: Awaitable) -> None: :param awaitable: an Awaitable instance """ + print(f"Awaitable {awaitable.key} finished.") self.logger.debug(f"Awaitable {awaitable.key} finished.") pk = awaitable.workgraph_pk + node = load_node(awaitable.pk) + print("node: ", node) + print("state: ", node.process_state) if isinstance(awaitable.pk, int): self.logger.info( @@ -407,18 +418,28 @@ def _on_awaitable_finished(self, awaitable: Awaitable) -> None: self.set_task_state_info(pk, awaitable.key, "state", "KILLED") # set child tasks state to SKIPPED self.set_tasks_state( - self.ctx._connectivity["child_node"][awaitable.key], "SKIPPED" + pk, + self.ctx._workgraph[pk]["_connectivity"]["child_node"][ + awaitable.key + ], + "SKIPPED", ) self.report(f"Task: {awaitable.key} cancelled.") else: results = awaitable.result() - self.set_normal_task_results(awaitable.key, results) + self.set_normal_task_results( + awaitable.workgraph_pk, awaitable.key, results + ) except Exception as e: self.logger.error(f"Error in awaitable {awaitable.key}: {e}") self.set_task_state_info(pk, awaitable.key, "state", "FAILED") # set child tasks state to SKIPPED self.set_tasks_state( - self.ctx._connectivity["child_node"][awaitable.key], "SKIPPED" + pk, + self.ctx._workgraph[pk]["_connectivity"]["child_node"][ + awaitable.key + ], + "SKIPPED", ) self.report(f"Task: {awaitable.key} failed.") self.run_error_handlers(pk, awaitable.key) @@ -428,6 +449,7 @@ def _on_awaitable_finished(self, awaitable: Awaitable) -> None: # node finished, update the task state and result # udpate the task state + print(f"Update task state: {awaitable.key}") self.update_task_state(awaitable.workgraph_pk, awaitable.key) # try to resume the workgraph, if the workgraph is already resumed # by other awaitable, this will not work @@ -451,6 +473,7 @@ def setup(self) -> None: # track if the awaitable callback is added to the runner self.ctx._workgraph = {} + self.ctx._max_number_awaitables = 10000 awaitable = Awaitable( **{ "workgraph_pk": self.node.pk, @@ -463,29 +486,55 @@ def setup(self) -> None: self.ctx._workgraph[self.node.pk] = {"_awaitable_actions": []} self.to_context(scheduler=awaitable) # self.ctx._msgs = [] - # self.ctx._execution_count = {} + # self.ctx._workgraph[pk]["_execution_count"] = {} # data not to be persisted, because they are not serializable self._temp = {"awaitables": {}} + # self.launch_workgraph(122305) def launch_workgraph(self, pk: str) -> None: """Launch the workgraph.""" # create the workgraph process + self.report(f"Launch workgraph: {pk}") self.init_ctx_workgraph(pk) - self.set_task_results(pk) + self.ctx._workgraph[pk]["_node"].set_process_state(Running.LABEL) + self.init_task_results(pk) + self.continue_workgraph(pk) + + def init_ctx_workgraph(self, pk: int) -> None: + """Init the context from the workgraph data.""" + from aiida_workgraph.utils import update_nested_dict + + self.report(f"Init workgraph: {pk}") + # read the latest workgraph data + wgdata, node = self.read_wgdata_from_base(pk) + self.ctx._workgraph[pk] = { + "_awaitable_actions": {}, + "_new_data": {}, + "_execution_count": 1, + "_executed_tasks": [], + "_count": 0, + "_context": {}, + "_node": node, + } + for key, value in wgdata["context"].items(): + key = key.replace("__", ".") + update_nested_dict(self.ctx._workgraph[pk], key, value) + # set up the workgraph + self.setup_ctx_workgraph(pk, wgdata) - def setup_ctx_workgraph(self, wgdata: t.Dict[str, t.Any]) -> None: + def setup_ctx_workgraph(self, pk: int, wgdata: t.Dict[str, t.Any]) -> None: """setup the workgraph in the context.""" import cloudpickle as pickle - pk = wgdata["pk"] - self.ctx._workgraph[pk]["_tasks"] = wgdata["tasks"] - self.ctx._workgraph[pk]["_links"] = wgdata["links"] - self.ctx._workgraph[pk]["_connectivity"] = wgdata["connectivity"] - self.ctx._workgraph[pk]["_ctrl_links"] = wgdata["ctrl_links"] - self.ctx._workgraph[pk]["_workgraph"] = wgdata + self.report(f"Setup workgraph: {pk}") + self.ctx._workgraph[pk]["_tasks"] = wgdata.pop("tasks") + self.ctx._workgraph[pk]["_links"] = wgdata.pop("links") + self.ctx._workgraph[pk]["_connectivity"] = wgdata.pop("connectivity") + self.ctx._workgraph[pk]["_ctrl_links"] = wgdata.pop("ctrl_links") self.ctx._workgraph[pk]["_error_handlers"] = pickle.loads( - wgdata["error_handlers"] + wgdata.pop("error_handlers") ) + self.ctx._workgraph[pk]["_workgraph"] = wgdata self.ctx._workgraph[pk]["_awaitable_actions"] = [] def read_wgdata_from_base(self, pk: int) -> t.Dict[str, t.Any]: @@ -503,19 +552,19 @@ def read_wgdata_from_base(self, pk: int) -> t.Dict[str, t.Any]: wgdata["error_handlers"] = deserialize_unsafe(wgdata["error_handlers"]) return wgdata, node - def update_workgraph_from_base(self) -> None: + def update_workgraph_from_base(self, pk: int) -> None: """Update the ctx from base.extras.""" wgdata, _ = self.read_wgdata_from_base() for name, task in wgdata["tasks"].items(): task["results"] = self.ctx._workgraph[pk]["_tasks"][name].get("results") - self.setup_ctx_workgraph(wgdata) + self.setup_ctx_workgraph(pk, wgdata) def get_task(self, name: str): """Get task from the context.""" task = Task.from_dict(self.ctx._workgraph[pk]["_tasks"][name]) return task - def update_task(self, task: Task): + def update_task(self, pk, task: Task): """Update task in the context. This is used in error handlers to update the task parameters.""" self.ctx._workgraph[pk]["_tasks"][task.name][ @@ -546,31 +595,13 @@ def set_task_state_info(self, pk: int, name: str, key: str, value: any) -> None: ) self.ctx._workgraph[pk]["_tasks"][name][key] = value - def init_ctx_workgraph(self, pk: int) -> None: - """Init the context from the workgraph data.""" - from aiida_workgraph.utils import update_nested_dict - - # read the latest workgraph data - wgdata, node = self.read_wgdata_from_base(pk) - self.ctx._workgraph[pk] = { - "_awaitable_actions": {}, - "_new_data": {}, - "_executed_tasks": {}, - "_count": 0, - "_context": {}, - "_node": node, - } - for key, value in wgdata["_context"].items(): - key = key.replace("__", ".") - update_nested_dict(self.ctx._workgraph[pk], key, value) - # set up the workgraph - self.setup_ctx_workgraph(wgdata) - - def set_task_results(self, pk) -> None: + def init_task_results(self, pk) -> None: + """Init the task results.""" for name, task in self.ctx._workgraph[pk]["_tasks"].items(): if self.get_task_state_info(pk, name, "action").upper() == "RESET": self.reset_task(pk, task["name"]) - self.update_task_state(pk, name) + # only init the task results, and do not need to continue the workgraph + self.update_task_state(pk, name, continue_workgraph=False) def apply_action(self, msg: dict) -> None: @@ -605,6 +636,7 @@ def apply_task_actions(self, msg: dict) -> None: def reset_task( self, + pk: int, name: str, reset_process: bool = True, recursive: bool = True, @@ -629,7 +661,7 @@ def reset_task( self.reset_task(child_task, reset_process=False, recursive=False) if recursive: # reset its child tasks - names = self.ctx._connectivity["child_node"][name] + names = self.ctx._workgraph[pk]["_connectivity"]["child_node"][name] for name in names: self.reset_task(name, recursive=False) @@ -667,7 +699,13 @@ def kill_task(self, pk, name: str) -> None: self.logger.error(f"Error in killing task {name}: {e}") def continue_workgraph(self, pk: int) -> None: - print("Continue workgraph.") + is_finished, _ = self.is_workgraph_finished(pk) + if is_finished: + self.report(f"Workgraph {pk} finished.") + self.ctx._workgraph[pk]["_node"].set_process_state(Finished.LABEL) + self.ctx._workgraph[pk]["_node"].set_exit_status(0) + self.ctx._workgraph[pk]["_node"].seal() + return self.report("Continue workgraph.") # self.update_workgraph_from_base() task_to_run = [] @@ -689,17 +727,28 @@ def continue_workgraph(self, pk: int) -> None: if ready: task_to_run.append(name) # - self.report("tasks ready to run: {}".format(",".join(task_to_run))) - self.run_tasks(pk, task_to_run) + self.report( + "tasks ready to run in WorkGraph {}, tasks: {}".format( + pk, ",".join(task_to_run) + ) + ) + if len(task_to_run) > 0: + self.run_tasks(pk, task_to_run) - def update_task_state(self, pk: int, name: str) -> None: + def update_task_state( + self, pk: int, name: str, continue_workgraph: bool = True + ) -> None: """Update task state when the task is finished.""" + + print("update task state: ", pk, name) task = self.ctx._workgraph[pk]["_tasks"][name] # print(f"set task result: {name}") node = self.get_task_state_info(pk, name, "process") + print("node", node) if isinstance(node, orm.ProcessNode): # print(f"set task result: {name} process") state = node.process_state.value.upper() + print("state", state) if node.is_finished_ok: self.set_task_state_info(pk, task["name"], "state", state) if task["metadata"]["node_type"].upper() == "WORKGRAPH": @@ -719,12 +768,15 @@ def update_task_state(self, pk: int, name: str) -> None: self.report(f"Workgraph: {pk}, Task: {name} finished.") # all other states are considered as failed else: + print(f"set task result: {name} failed") task["results"] = node.outputs # self.ctx._new_data[name] = task["results"] self.set_task_state_info(pk, task["name"], "state", "FAILED") # set child tasks state to SKIPPED self.set_tasks_state( - self.ctx._workgraph["_connectivity"]["child_node"][name], "SKIPPED" + pk, + self.ctx._workgraph[pk]["_connectivity"]["child_node"][name], + "SKIPPED", ) self.report(f"Workgraph: {pk}, Task: {name} failed.") self.run_error_handlers(pk, name) @@ -737,6 +789,8 @@ def update_task_state(self, pk: int, name: str) -> None: task.setdefault("results", None) self.update_parent_task_state(pk, name) + if continue_workgraph: + self.continue_workgraph(pk) def set_normal_task_results(self, pk, name, results): """Set the results of a normal task. @@ -757,7 +811,7 @@ def set_normal_task_results(self, pk, name, results): self.report(f"Workgraph: {pk}, Task: {name} finished.") self.update_parent_task_state(pk, name) - def update_parent_task_state(self, name: str) -> None: + def update_parent_task_state(self, pk, name: str) -> None: """Update parent task state.""" parent_task = self.ctx._workgraph[pk]["_tasks"][name]["parent_task"] if parent_task[0]: @@ -796,7 +850,7 @@ def update_zone_task_state(self, name: str) -> None: self.update_parent_task_state(pk, name) self.report(f"Task: {name} finished.") - def should_run_while_task(self, name: str) -> tuple[bool, t.Any]: + def should_run_while_task(self, pk: int, name: str) -> tuple[bool, t.Any]: """Check if the while task should run.""" # check the conditions of the while task not_excess_max_iterations = ( @@ -806,7 +860,7 @@ def should_run_while_task(self, name: str) -> tuple[bool, t.Any]: ] ) conditions = [not_excess_max_iterations] - _, kwargs, _, _, _ = self.get_inputs(name) + _, kwargs, _, _, _ = self.get_inputs(pk, name) if isinstance(kwargs["conditions"], list): for condition in kwargs["conditions"]: value = get_nested_dict(self.ctx, condition) @@ -820,7 +874,7 @@ def should_run_while_task(self, name: str) -> tuple[bool, t.Any]: def should_run_if_task(self, name: str) -> tuple[bool, t.Any]: """Check if the IF task should run.""" - _, kwargs, _, _, _ = self.get_inputs(name) + _, kwargs, _, _, _ = self.get_inputs(pk, name) flag = kwargs["conditions"] if kwargs["invert_condition"]: return not flag @@ -840,12 +894,12 @@ def are_childen_finished(self, pk, name: str) -> tuple[bool, t.Any]: break return finished, None - def run_error_handlers(self, pk, task_name: str) -> None: + def run_error_handlers(self, pk: int, task_name: str) -> None: """Run error handler.""" - node = self.get_task_state_info(task_name, "process") + node = self.get_task_state_info(pk, task_name, "process") if not node or not node.exit_status: return - for _, data in self.ctx._error_handlers.items(): + for _, data in self.ctx._workgraph[pk]["_error_handlers"].items(): if task_name in data["tasks"]: handler = data["handler"] metadata = data["tasks"][task_name] @@ -862,7 +916,7 @@ def is_workgraph_finished(self, pk) -> bool: is_finished = True failed_tasks = [] for name, task in self.ctx._workgraph[pk]["_tasks"].items(): - # self.update_task_state(name) + # self.update_task_state(pk, name) if self.get_task_state_info(pk, task["name"], "state") in [ "RUNNING", "CREATED", @@ -873,10 +927,13 @@ def is_workgraph_finished(self, pk) -> bool: elif self.get_task_state_info(pk, task["name"], "state") == "FAILED": failed_tasks.append(name) if is_finished: - if self.ctx._workgraph["workgraph_type"].upper() == "WHILE": + if ( + self.ctx._workgraph[pk]["_workgraph"]["workgraph_type"].upper() + == "WHILE" + ): should_run = self.check_while_conditions(pk) is_finished = not should_run - if self.ctx._workgraph["workgraph_type"].upper() == "FOR": + if self.ctx._workgraph[pk]["_workgraph"]["workgraph_type"].upper() == "FOR": should_run = self.check_for_conditions(pk) is_finished = not should_run if is_finished and len(failed_tasks) > 0: @@ -892,17 +949,20 @@ def check_while_conditions(self, pk: int) -> bool: Run all condition tasks and check if all the conditions are True. """ self.report("Check while conditions.") - if self.ctx._execution_count >= self.ctx._max_iteration: + if ( + self.ctx._workgraph[pk]["_execution_count"] + >= self.ctx._workgraph[pk]["_max_iteration"] + ): self.report("Max iteration reached.") return False condition_tasks = [] - for c in self.ctx._workgraph["conditions"]: + for c in self.ctx._workgraph[pk]["conditions"]: task_name, socket_name = c.split(".") if "task_name" != "context": condition_tasks.append(task_name) self.run_tasks(condition_tasks, continue_workgraph=False) conditions = [] - for c in self.ctx._workgraph["conditions"]: + for c in self.ctx._workgraph[pk]["conditions"]: task_name, socket_name = c.split(".") if task_name == "context": conditions.append(self.ctx[socket_name]) @@ -912,21 +972,21 @@ def check_while_conditions(self, pk: int) -> bool: ) should_run = False not in conditions if should_run: - self.reset() - self.set_tasks_state(condition_tasks, "SKIPPED") + self.reset_workgraph(pk) + self.set_tasks_state(pk, condition_tasks, "SKIPPED") return should_run def check_for_conditions(self, pk: int) -> bool: - condition_tasks = [c[0] for c in self.ctx._workgraph["conditions"]] + condition_tasks = [c[0] for c in self.ctx._workgraph[pk]["conditions"]] self.run_tasks(condition_tasks) conditions = [self.ctx._count < len(self.ctx._sequence)] + [ self.ctx._workgraph[pk]["_tasks"][c[0]]["results"][c[1]] - for c in self.ctx._workgraph["conditions"] + for c in self.ctx._workgraph[pk]["conditions"] ] should_run = False not in conditions if should_run: - self.reset() - self.set_tasks_state(condition_tasks, "SKIPPED") + self.reset_workgraph(pk) + self.set_tasks_state(pk, condition_tasks, "SKIPPED") self.ctx["i"] = self.ctx._sequence[self.ctx._count] self.ctx._count += 1 return should_run @@ -939,6 +999,19 @@ def remove_executed_task(self, pk, name: str) -> None: if label.split(".")[0] != name ] + def add_task_link(self, pk, node: ProcessNode) -> None: + from aiida.common.links import LinkType + + parent_calc = self.ctx._workgraph[pk]["_node"] + if isinstance(node, orm.CalculationNode): + node.base.links.add_incoming( + parent_calc, LinkType.CALL_CALC, "CALL" + ) # TODO, self.metadata.call_link_label) + elif isinstance(node, orm.WorkflowNode): + node.base.links.add_incoming( + parent_calc, LinkType.CALL_WORK, "CALL" + ) # TODO, self.metadata.call_link_label) + def run_tasks( self, pk: int, names: t.List[str], continue_workgraph: bool = True ) -> None: @@ -955,6 +1028,8 @@ def run_tasks( create_data_node, update_nested_dict_with_special_keys, ) + from aiida_workgraph.engine.workgraph import WorkGraphEngine + from aiida_workgraph.engine import launch for name in names: # skip if the max number of awaitables is reached @@ -983,7 +1058,7 @@ def run_tasks( self.report(f"Run task: {name}, type: {task['metadata']['node_type']}") executor, _ = get_executor(task["executor"]) # print("executor: ", executor) - args, kwargs, var_args, var_kwargs, args_dict = self.get_inputs(name) + args, kwargs, var_args, var_kwargs, args_dict = self.get_inputs(pk, name) for i, key in enumerate( self.ctx._workgraph[pk]["_tasks"][name]["metadata"]["args"] ): @@ -999,7 +1074,7 @@ def run_tasks( if task["metadata"]["node_type"].upper() == "NODE": results = self.run_executor(executor, [], kwargs, var_args, var_kwargs) self.set_task_state_info(pk, name, "process", results) - self.update_task_state(name) + self.update_task_state(pk, name) if continue_workgraph: self.continue_workgraph(pk) elif task["metadata"]["node_type"].upper() == "DATA": @@ -1007,7 +1082,7 @@ def run_tasks( kwargs.pop(key, None) results = create_data_node(executor, args, kwargs) self.set_task_state_info(pk, name, "process", results) - self.update_task_state(name) + self.update_task_state(pk, name) self.ctx._new_data[name] = results if continue_workgraph: self.continue_workgraph(pk) @@ -1017,24 +1092,29 @@ def run_tasks( ]: kwargs.setdefault("metadata", {}) kwargs["metadata"].update({"call_link_label": name}) + kwargs["_parent_pid"] = pk try: # since aiida 2.5.0, we need to use args_dict to pass the args to the run_get_node if var_kwargs is None: - results, process = run_get_node(executor, **kwargs) + results, process = launch.run_get_node( + executor.process_class, **kwargs + ) else: - results, process = run_get_node( - executor, **kwargs, **var_kwargs + results, process = launch.run_get_node( + executor.process_class, **kwargs, **var_kwargs ) process.label = name # print("results: ", results) self.set_task_state_info(pk, name, "process", process) - self.update_task_state(name) + self.update_task_state(pk, name) except Exception as e: self.logger.error(f"Error in task {name}: {e}") self.set_task_state_info(pk, name, "state", "FAILED") # set child state to FAILED self.set_tasks_state( - self.ctx._connectivity["child_node"][name], "SKIPPED" + pk, + self.ctx._workgraph[pk]["_connectivity"]["child_node"][name], + "SKIPPED", ) self.report(f"Task: {name} failed.") # exclude the current tasks from the next run @@ -1044,6 +1124,7 @@ def run_tasks( # process = run_get_node(executor, *args, **kwargs) kwargs.setdefault("metadata", {}) kwargs["metadata"].update({"call_link_label": name}) + kwargs["_parent_pid"] = pk # transfer the args to kwargs if self.get_task_state_info(pk, name, "action").upper() == "PAUSE": self.set_task_state_info(pk, name, "action", "") @@ -1057,9 +1138,10 @@ def run_tasks( self.set_task_state_info(pk, name, "state", "CREATED") process = process.node else: - process = self.submit(executor, **kwargs) + process = launch.submit(executor, **kwargs) self.set_task_state_info(pk, name, "state", "RUNNING") process.label = name + process.workgraph_pk = pk self.set_task_state_info(pk, name, "process", process) self.to_context(**{name: process}) elif task["metadata"]["node_type"].upper() in ["GRAPH_BUILDER"]: @@ -1070,7 +1152,9 @@ def run_tasks( ] wg.parent_uuid = self.node.uuid inputs = wg.prepare_inputs(metadata={"call_link_label": name}) - # process = self.submit(WorkGraphEngine, inputs=inputs) + inputs["parent_pid"] = pk + process = launch.submit(WorkGraphEngine, inputs=inputs) + process.workgraph_pk = pk self.set_task_state_info(pk, name, "process", process) self.set_task_state_info(pk, name, "state", "RUNNING") self.to_context(**{name: process}) @@ -1078,7 +1162,9 @@ def run_tasks( from .utils import prepare_for_workgraph_task inputs, _ = prepare_for_workgraph_task(task, kwargs) - # process = self.submit(WorkGraphEngine, inputs=inputs) + inputs["parent_pid"] = pk + process = launch.submit(WorkGraphEngine, inputs=inputs) + process.workgraph_pk = pk self.set_task_state_info(pk, name, "process", process) self.set_task_state_info(pk, name, "state", "RUNNING") self.to_context(**{name: process}) @@ -1087,6 +1173,7 @@ def run_tasks( from .utils import prepare_for_python_task inputs = prepare_for_python_task(task, kwargs, var_kwargs) + inputs["parent_pid"] = pk # since aiida 2.5.0, we can pass inputs directly to the submit, no need to use **inputs if self.get_task_state_info(pk, name, "action").upper() == "PAUSE": self.set_task_state_info(pk, name, "action", "") @@ -1100,9 +1187,10 @@ def run_tasks( self.set_task_state_info(pk, name, "state", "CREATED") process = process.node else: - process = self.submit(PythonJob, **inputs) + process = launch.submit(PythonJob, **inputs) self.set_task_state_info(pk, name, "state", "RUNNING") process.label = name + process.workgraph_pk = pk self.set_task_state_info(pk, name, "process", process) self.to_context(**{name: process}) elif task["metadata"]["node_type"].upper() in ["SHELLJOB"]: @@ -1110,6 +1198,7 @@ def run_tasks( from .utils import prepare_for_shell_task inputs = prepare_for_shell_task(task, kwargs) + inputs["parent_pid"] = pk if self.get_task_state_info(pk, name, "action").upper() == "PAUSE": self.set_task_state_info(pk, name, "action", "") self.report(f"Task {name} is created and paused.") @@ -1122,9 +1211,10 @@ def run_tasks( self.set_task_state_info(pk, name, "state", "CREATED") process = process.node else: - process = self.submit(ShellJob, **inputs) + process = launch.submit(ShellJob, **inputs) self.set_task_state_info(pk, name, "state", "RUNNING") process.label = name + process.workgraph_pk = pk self.set_task_state_info(pk, name, "process", process) self.to_context(**{name: process}) elif task["metadata"]["node_type"].upper() in ["WHILE"]: @@ -1133,7 +1223,9 @@ def run_tasks( if not should_run: self.set_task_state_info(pk, name, "state", "FINISHED") self.set_tasks_state( - self.ctx._workgraph[pk]["_tasks"][name]["children"], "SKIPPED" + pk, + self.ctx._workgraph[pk]["_tasks"][name]["children"], + "SKIPPED", ) self.update_parent_task_state(pk, name) self.report( @@ -1148,7 +1240,7 @@ def run_tasks( if should_run: self.set_task_state_info(pk, name, "state", "RUNNING") else: - self.set_tasks_state(task["children"], "SKIPPED") + self.set_tasks_state(pk, task["children"], "SKIPPED") self.update_zone_task_state(name) self.continue_workgraph(pk) elif task["metadata"]["node_type"].upper() in ["ZONE"]: @@ -1203,13 +1295,15 @@ def run_tasks( results = self.run_executor( executor, args, kwargs, var_args, var_kwargs ) - self.set_normal_task_results(name, results) + self.set_normal_task_results(pk, name, results) except Exception as e: self.logger.error(f"Error in task {name}: {e}") self.set_task_state_info(pk, name, "state", "FAILED") # set child tasks state to SKIPPED self.set_tasks_state( - self.ctx._connectivity["child_node"][name], "SKIPPED" + pk, + self.ctx._workgraph[pk]["_connectivity"]["child_node"][name], + "SKIPPED", ) self.report(f"Task: {name} failed.") self.run_error_handlers(pk, name) @@ -1238,7 +1332,7 @@ def construct_awaitable_function( return awaitable def get_inputs( - self, name: str + self, pk: int, name: str ) -> t.Tuple[ t.List[t.Any], t.Dict[str, t.Any], @@ -1348,7 +1442,7 @@ def update_context_variable(self, value: t.Any) -> t.Any: return get_nested_dict(self.ctx, name) return value - def task_set_context(self, name: str) -> None: + def task_set_context(self, pk, name: str) -> None: """Export task result to context.""" from aiida_workgraph.utils import update_nested_dict @@ -1357,7 +1451,7 @@ def task_set_context(self, name: str) -> None: result = self.ctx._workgraph[pk]["_tasks"][name]["results"][key] update_nested_dict(self.ctx, value, result) - def is_task_ready_to_run(self, name: str) -> t.Tuple[bool, t.Optional[str]]: + def is_task_ready_to_run(self, pk, name: str) -> t.Tuple[bool, t.Optional[str]]: """Check if the task ready to run. For normal task and a zone task, we need to check its input tasks in the connectivity["zone"]. For task inside a zone, we need to check if the zone (parent task) is ready. @@ -1367,13 +1461,15 @@ def is_task_ready_to_run(self, name: str) -> t.Tuple[bool, t.Optional[str]]: parent_states = [True, True] # if the task belongs to a parent zone if parent_task[0]: - state = self.get_task_state_info(parent_task[0], "state") + state = self.get_task_state_info(pk, parent_task[0], "state") if state not in ["RUNNING"]: parent_states[1] = False # check the input tasks of the zone # check if the zone input tasks are ready - for child_task_name in self.ctx._connectivity["zone"][name]["input_tasks"]: - if self.get_task_state_info(child_task_name, "state") not in [ + for child_task_name in self.ctx._workgraph[pk]["_connectivity"]["zone"][name][ + "input_tasks" + ]: + if self.get_task_state_info(pk, child_task_name, "state") not in [ "FINISHED", "SKIPPED", "FAILED", @@ -1383,20 +1479,20 @@ def is_task_ready_to_run(self, name: str) -> t.Tuple[bool, t.Optional[str]]: return all(parent_states), parent_states - def reset(self) -> None: - self.ctx._execution_count += 1 - self.set_tasks_state(self.ctx._workgraph[pk]["_tasks"].keys(), "PLANNED") + def reset_workgraph(self, pk) -> None: + self.ctx._workgraph[pk]["_execution_count"] += 1 + self.set_tasks_state(pk, self.ctx._workgraph[pk]["_tasks"].keys(), "PLANNED") self.ctx._workgraph[pk]["_executed_tasks"] = [] def set_tasks_state( - self, tasks: t.Union[t.List[str], t.Sequence[str]], value: str + self, pk: int, tasks: t.Union[t.List[str], t.Sequence[str]], value: str ) -> None: """Set tasks state""" for name in tasks: self.set_task_state_info(pk, name, "state", value) if "children" in self.ctx._workgraph[pk]["_tasks"][name]: self.set_tasks_state( - self.ctx._workgraph[pk]["_tasks"][name]["children"], value + pk, self.ctx._workgraph[pk]["_tasks"][name]["children"], value ) def run_executor( @@ -1516,7 +1612,10 @@ def finalize(self) -> t.Optional[ExitCode]: self.out_many(group_outputs) # output the new data self.out("new_data", self.ctx._new_data) - self.out("execution_count", orm.Int(self.ctx._execution_count).store()) + self.out( + "execution_count", + orm.Int(self.ctx._workgraph[pk]["_execution_count"]).store(), + ) self.report("Finalize workgraph.") for _, task in self.ctx._workgraph[pk]["_tasks"].items(): if self.get_task_state_info(pk, task["name"], "state") == "FAILED": diff --git a/aiida_workgraph/workgraph.py b/aiida_workgraph/workgraph.py index 21b132cf..88bcce5f 100644 --- a/aiida_workgraph/workgraph.py +++ b/aiida_workgraph/workgraph.py @@ -134,7 +134,7 @@ def submit( self.save(metadata=metadata) if self.process.process_state.value.upper() not in ["CREATED"]: raise ValueError(f"Process {self.process.pk} has already been submitted.") - self.continue_process_in_scheduler() + self.continue_process() # as long as we submit the process, it is a new submission, we should set restart_process to None self.restart_process = None if wait: @@ -415,7 +415,7 @@ def continue_process(self): process_controller = get_manager().get_process_controller() process_controller.continue_process(self.pk) - def continue_process_in_scheduler(self, scheduler_pk: int = 122006): + def continue_process_in_scheduler(self, scheduler_pk: int = 122744): """Ask the scheduler to pick up the process from the database and run it.""" from aiida_workgraph.utils.control import create_task_action