From fa630d3e560186d2856adaa30a75a636d8da7574 Mon Sep 17 00:00:00 2001 From: superstar54 Date: Sat, 14 Dec 2024 19:34:34 +0100 Subject: [PATCH] Using a pickled function object --- src/aiida_pythonjob/calculations/pythonjob.py | 202 +++++++++++------- src/aiida_pythonjob/calculations/utils.py | 108 ++++++++++ src/aiida_pythonjob/data/pickled_data.py | 8 + src/aiida_pythonjob/launch.py | 22 +- src/aiida_pythonjob/parsers/pythonjob.py | 81 +++++-- src/aiida_pythonjob/utils.py | 61 ++---- tests/test_parser.py | 63 ++++-- tests/test_pythonjob.py | 18 +- tests/test_utils.py | 7 +- 9 files changed, 400 insertions(+), 170 deletions(-) create mode 100644 src/aiida_pythonjob/calculations/utils.py diff --git a/src/aiida_pythonjob/calculations/pythonjob.py b/src/aiida_pythonjob/calculations/pythonjob.py index 5299efd..8b50bb3 100644 --- a/src/aiida_pythonjob/calculations/pythonjob.py +++ b/src/aiida_pythonjob/calculations/pythonjob.py @@ -1,4 +1,4 @@ -"""Calcjob to run a Python function on a remote computer.""" +"""Calcjob to run a Python function on a remote computer, either via raw source code or a pickled function.""" from __future__ import annotations @@ -11,7 +11,6 @@ from aiida.engine import CalcJob, CalcJobProcessSpec from aiida.orm import ( Data, - Dict, FolderData, List, RemoteData, @@ -24,7 +23,12 @@ class PythonJob(CalcJob): - """Calcjob to run a Python function on a remote computer.""" + """Calcjob to run a Python function on a remote computer. + + Supports two modes: + 1) Loading a pickled function object (function_data.pickled_function). + 2) Embedding raw source code for the function (function_data.source_code). + """ _internal_retrieve_list = [] _retrieve_singlefile_list = [] @@ -36,16 +40,16 @@ class PythonJob(CalcJob): @classmethod def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override] - """Define the process specification, including its inputs, outputs and known exit codes. - - :param spec: the calculation job process spec to define. 
- """ + """Define the process specification, including its inputs, outputs and known exit codes.""" super().define(spec) - spec.input("function_data", valid_type=Dict, serializer=to_aiida_type, required=False) + spec.input_namespace("function_data") + spec.input("function_data.name", valid_type=Str, serializer=to_aiida_type) + spec.input("function_data.source_code", valid_type=Str, serializer=to_aiida_type, required=False) + spec.input("function_data.outputs", valid_type=List, serializer=to_aiida_type, required=False) + spec.input("function_data.pickled_function", valid_type=Data, required=False) + spec.input("function_data.mode", valid_type=Str, serializer=to_aiida_type, required=False) spec.input("process_label", valid_type=Str, serializer=to_aiida_type, required=False) - spec.input_namespace( - "function_inputs", valid_type=Data, required=False - ) # , serializer=serialize_to_aiida_nodes) + spec.input_namespace("function_inputs", valid_type=Data, required=False) spec.input( "parent_folder", valid_type=(RemoteData, FolderData, SinglefileData), @@ -57,8 +61,8 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override] valid_type=Str, required=False, serializer=to_aiida_type, - help="""Default name of the subfolder that you want to create in the working directory, - in which you want to place the files taken from parent_folder""", + help="""Default name of the subfolder to create in the working directory + where the files from parent_folder are placed.""", ) spec.input( "parent_output_folder", @@ -86,7 +90,7 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override] default=None, required=False, serializer=to_aiida_type, - help="The names of the files to retrieve", + help="Additional filenames to retrieve from the remote work directory", ) spec.outputs.dynamic = True # set default options (optional) @@ -97,7 +101,6 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override] "num_machines": 1, "num_mpiprocs_per_machine": 1, } - # start exit codes - marker for docs spec.exit_code( 310, "ERROR_READING_OUTPUT_FILE", @@ -116,86 +119,131 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override] invalidates_cache=True, message="The number of results does not match the number of outputs.", ) + spec.exit_code( + 322, + "ERROR_IMPORT_CLOUDPICKLE_FAILED", + invalidates_cache=True, + message="Importing cloudpickle failed.\n{exception}\n{traceback}", + ) + spec.exit_code( + 323, + "ERROR_UNPICKLE_INPUTS_FAILED", + invalidates_cache=True, + message="Failed to unpickle inputs.\n{exception}\n{traceback}", + ) + spec.exit_code( + 324, + "ERROR_UNPICKLE_FUNCTION_FAILED", + invalidates_cache=True, + message="Failed to unpickle user function.\n{exception}\n{traceback}", + ) + spec.exit_code( + 325, + "ERROR_FUNCTION_EXECUTION_FAILED", + invalidates_cache=True, + message="Function execution failed.\n{exception}\n{traceback}", + ) + spec.exit_code( + 326, + "ERROR_PICKLE_RESULTS_FAILED", + invalidates_cache=True, + message="Failed to pickle results.\n{exception}\n{traceback}", + ) + spec.exit_code( + 327, + "ERROR_SCRIPT_FAILED", + invalidates_cache=True, + message="The script failed for an unknown reason.\n{exception}\n{traceback}", + ) - def _build_process_label(self) -> str: - """Use the function name as the process label. 
+    def get_function_name(self) -> str:
+        """Return the name of the function to run."""
+        if "name" in self.inputs.function_data:
+            name = self.inputs.function_data.name.value
+        else:
+            try:
+                name = self.inputs.function_data.pickled_function.value.__name__
+            except AttributeError:
+                # If the user does not specify a name, fall back to a generic one
+                name = "anonymous_function"
+        return name
 
-        :returns: The process label to use for ``ProcessNode`` instances.
-        """
+    def _build_process_label(self) -> str:
+        """Use the function name or an explicit label as the process label."""
         if "process_label" in self.inputs:
             return self.inputs.process_label.value
         else:
-            data = self.inputs.function_data.get_dict()
-            return f"PythonJob<{data['name']}>"
+            name = self.get_function_name()
+            return f"PythonJob<{name}>"
 
     def on_create(self) -> None:
         """Called when a Process is created."""
-
         super().on_create()
         self.node.label = self._build_process_label()
 
     def prepare_for_submission(self, folder: Folder) -> CalcInfo:
         """Prepare the calculation for submission.
 
-        1) Write the python script to the folder.
+        1) Write the python script to the folder, depending on the mode (source vs. pickled function).
         2) Write the inputs to a pickle file and save it to the folder.
-
-        :param folder: A temporary folder on the local file system.
-        :returns: A :class:`aiida.common.datastructures.CalcInfo` instance.
         """
         import cloudpickle as pickle
 
+        from aiida_pythonjob.calculations.utils import generate_script_py
+
         dirpath = pathlib.Path(folder._abspath)
-        inputs: dict[str, t.Any]
 
+        # Prepare the dictionary of input arguments for the function
+        inputs: dict[str, t.Any]
         if self.inputs.function_inputs:
             inputs = dict(self.inputs.function_inputs)
         else:
             inputs = {}
+
+        # Prepare the final subfolder name for the parent folder
         if "parent_folder_name" in self.inputs:
             parent_folder_name = self.inputs.parent_folder_name.value
         else:
             parent_folder_name = self._DEFAULT_PARENT_FOLDER_NAME
-        function_data = self.inputs.function_data.get_dict()
-        # create python script to run the function
-        script = f"""
-import pickle
-
-# define the function
-{function_data["source_code"]}
-
-# load the inputs from the pickle file
-with open('inputs.pickle', 'rb') as handle:
-    inputs = pickle.load(handle)
-
-# run the function
-result = {function_data["name"]}(**inputs)
-# save the result as a pickle file
-with open('results.pickle', 'wb') as handle:
-    pickle.dump(result, handle)
-"""
-        # write the script to the folder
+
+        function_data = self.inputs.function_data
+
+        # Build the Python script
+        source_code = function_data.get("source_code")
+        if "pickled_function" in self.inputs.function_data:
+            pickled_function = self.inputs.function_data.pickled_function.get_serialized_value()
+        else:
+            pickled_function = None
+        # Generate script.py content; the name is only used in source-code mode
+        function_name = self.get_function_name()
+        script_content = generate_script_py(
+            pickled_function=pickled_function,
+            source_code=source_code.value if source_code else None,
+            function_name=function_name,
+        )
+
+        # Write the script to the working folder
         with folder.open(self.options.input_filename, "w", encoding="utf8") as handle:
-            handle.write(script)
-        # symlink = settings.pop('PARENT_FOLDER_SYMLINK', False)
-        symlink = True
+            handle.write(script_content)
 
+        # Symlink or copy approach for the parent folder
+        symlink = True
         remote_copy_list = []
         local_copy_list = []
         remote_symlink_list = []
         remote_list = remote_symlink_list if symlink else remote_copy_list
 
         source = self.inputs.get("parent_folder",
None)
-
         if source is not None:
             if isinstance(source, RemoteData):
-                dirpath = pathlib.Path(source.get_remote_path())
+                # Possibly append parent_output_folder path
+                dirpath_remote = pathlib.Path(source.get_remote_path())
                 if self.inputs.parent_output_folder is not None:
-                    dirpath = pathlib.Path(source.get_remote_path()) / self.inputs.parent_output_folder.value
+                    dirpath_remote /= self.inputs.parent_output_folder.value
                 remote_list.append(
                     (
                         source.computer.uuid,
-                        str(dirpath),
+                        str(dirpath_remote),
                         parent_folder_name,
                     )
                 )
@@ -204,46 +252,56 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo:
                 local_copy_list.append((source.uuid, dirname, parent_folder_name))
             elif isinstance(source, SinglefileData):
                 local_copy_list.append((source.uuid, source.filename, source.filename))
+
+        # Upload additional files
         if "upload_files" in self.inputs:
             upload_files = self.inputs.upload_files
-            for key, source in upload_files.items():
+            for key, src in upload_files.items():
                 # replace "_dot_" with "." in the key
                 new_key = key.replace("_dot_", ".")
-                if isinstance(source, FolderData):
-                    local_copy_list.append((source.uuid, "", new_key))
-                elif isinstance(source, SinglefileData):
-                    local_copy_list.append((source.uuid, source.filename, source.filename))
+                if isinstance(src, FolderData):
+                    local_copy_list.append((src.uuid, "", new_key))
+                elif isinstance(src, SinglefileData):
+                    local_copy_list.append((src.uuid, src.filename, src.filename))
                 else:
                     raise ValueError(
-                        f"""Input folder/file: {source} is not supported.
-Only AiiDA SinglefileData and FolderData are allowed."""
+                        f"Input file/folder '{key}' of type {type(src)} is not supported. "
+                        "Only AiiDA SinglefileData and FolderData are allowed."
                     )
+
+        # Copy remote data if any
         if "copy_files" in self.inputs:
             copy_files = self.inputs.copy_files
-            for key, source in copy_files.items():
-                # replace "_dot_" with "." in the key
+            for key, src in copy_files.items():
                 new_key = key.replace("_dot_", ".")
-                dirpath = pathlib.Path(source.get_remote_path())
-                remote_list.append((source.computer.uuid, str(dirpath), new_key))
-        # create pickle file for the inputs
+                dirpath_remote = pathlib.Path(src.get_remote_path())
+                remote_list.append((src.computer.uuid, str(dirpath_remote), new_key))
+
+        # Create a pickle file for the user input values
         input_values = {}
         for key, value in inputs.items():
             if isinstance(value, Data) and hasattr(value, "value"):
-                # get the value of the pickled data
                 input_values[key] = value.value
-            # TODO: should check this recursively
             elif isinstance(value, (AttributeDict, dict)):
-                # if the value is an AttributeDict, use recursively
+                # Convert an AttributeDict/dict with .value items
                 input_values[key] = {k: v.value for k, v in value.items()}
             else:
                 raise ValueError(
-                    f"Input data {value} is not supported. Only AiiDA data Node with a value attribute is allowed. "
+                    f"Input data {value} is not supported. Only AiiDA Data nodes with a '.value' attribute or "
+                    "AttributeDict/dict-of-Data are allowed."
) - # save the value as a pickle file, the path is absolute + filename = "inputs.pickle" - dirpath = pathlib.Path(folder._abspath) with folder.open(filename, "wb") as handle: pickle.dump(input_values, handle) + + # If using a pickled function, we also need to upload the function pickle + if pickled_function: + # create a SinglefileData object for the pickled function + function_pkl_fname = "function.pkl" + with folder.open(function_pkl_fname, "wb") as handle: + handle.write(pickled_function) + # create a singlefiledata object for the pickled data file_data = SinglefileData(file=f"{dirpath}/{filename}") file_data.store() @@ -259,7 +317,7 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo: calcinfo.local_copy_list = local_copy_list calcinfo.remote_copy_list = remote_copy_list calcinfo.remote_symlink_list = remote_symlink_list - calcinfo.retrieve_list = ["results.pickle", self.options.output_filename] + calcinfo.retrieve_list = ["results.pickle", self.options.output_filename, "_error.json"] if self.inputs.additional_retrieve_list is not None: calcinfo.retrieve_list += self.inputs.additional_retrieve_list.get_list() calcinfo.retrieve_list += self._internal_retrieve_list diff --git a/src/aiida_pythonjob/calculations/utils.py b/src/aiida_pythonjob/calculations/utils.py new file mode 100644 index 0000000..8729aea --- /dev/null +++ b/src/aiida_pythonjob/calculations/utils.py @@ -0,0 +1,108 @@ +from __future__ import annotations + + +def generate_script_py( + pickled_function: bytes | None, source_code: str | None, function_name: str = "user_function" +) -> str: + """ + Generate the script.py content as a single string with robust exception handling. + + :param pickled_function: Serialized function bytes if running in pickled mode, else None. + :param source_code: Raw Python source code if running in source-code mode, else None. + :param function_name: The name of the function to call when running source code mode. + :return: A string representing the entire content of script.py. 
+ """ + # We build a list of lines, then join them with '\n' at the end + script_lines = [ + "import sys", + "import json", + "import traceback", + "", + "def write_error_file(error_type, exc, traceback_str):", + " # Write an error file to disk so the parser can detect the error", + " error_data = {", + " 'error_type': error_type,", + " 'exception_message': str(exc),", + " 'traceback': traceback_str,", + " }", + " with open('_error.json', 'w') as f:", + " json.dump(error_data, f, indent=2)", + "", + "def main():", + " # 1) Attempt to import cloudpickle", + " try:", + " import cloudpickle as pickle", + " except ImportError as e:", + " write_error_file('IMPORT_CLOUDPICKLE_FAILED', e, traceback.format_exc())", + " sys.exit(1)", + "", + " # 2) Attempt to unpickle the inputs", + " try:", + " with open('inputs.pickle', 'rb') as handle:", + " inputs = pickle.load(handle)", + " except Exception as e:", + " write_error_file('UNPICKLE_INPUTS_FAILED', e, traceback.format_exc())", + " sys.exit(1)", + "", + ] + + if pickled_function: + # Mode 1: pickled function + script_lines += [ + " # 3) Attempt to unpickle the function", + " try:", + " with open('function.pkl', 'rb') as f:", + " user_function = pickle.load(f)", + " except Exception as e:", + " write_error_file('UNPICKLE_FUNCTION_FAILED', e, traceback.format_exc())", + " sys.exit(1)", + "", + " # 4) Attempt to run the function", + " try:", + " result = user_function(**inputs)", + " except Exception as e:", + " write_error_file('FUNCTION_EXECUTION_FAILED', e, traceback.format_exc())", + " sys.exit(1)", + ] + elif source_code: + # Mode 2: raw source code + # Indent each line of source_code by 4 spaces to keep correct indentation + source_lines = [f" {line}" for line in source_code.split("\n")] + script_lines += [ + " # 3) Define the function from raw source code", + *source_lines, + "", + " # 4) Attempt to run the function", + " try:", + f" result = {function_name}(**inputs)", + " except Exception as e:", + " write_error_file('FUNCTION_EXECUTION_FAILED', e, traceback.format_exc())", + " sys.exit(1)", + ] + else: + raise ValueError("You must provide exactly one of 'source_code' or 'pickled_function'.") + + # 5) Attempt to pickle (save) the result + script_lines += [ + "", + " # 5) Attempt to pickle the result", + " try:", + " with open('results.pickle', 'wb') as handle:", + " pickle.dump(result, handle)", + " except Exception as e:", + " write_error_file('PICKLE_RESULTS_FAILED', e, traceback.format_exc())", + " sys.exit(1)", + "", + " # If we've made it this far, everything succeeded. Write an empty _error.json", + " # so the parser can always read _error.json (if it's empty, no error).", + " with open('_error.json', 'w') as f:", + " json.dump({}, f, indent=2)", + "", + "if __name__ == '__main__':", + " main()", + "", + ] + + # Join lines with newline + script_content = "\n".join(script_lines) + return script_content diff --git a/src/aiida_pythonjob/data/pickled_data.py b/src/aiida_pythonjob/data/pickled_data.py index e8bbed0..85e7dbe 100644 --- a/src/aiida_pythonjob/data/pickled_data.py +++ b/src/aiida_pythonjob/data/pickled_data.py @@ -58,6 +58,14 @@ def _get_value_from_file(self): "Please ensure that the correct environment and cloudpickle version are being used." ) from e + def get_serialized_value(self): + """Return the serialized value stored in the repository. + + :return: The serialized value. 
+        """
+        with self.base.repository.open(self.FILENAME, mode="rb") as f:
+            return f.read()
+
     def set_value(self, value):
         """Set the contents of this node by pickling the provided value.
 
diff --git a/src/aiida_pythonjob/launch.py b/src/aiida_pythonjob/launch.py
index e759544..5039603 100644
--- a/src/aiida_pythonjob/launch.py
+++ b/src/aiida_pythonjob/launch.py
@@ -53,32 +53,18 @@ def prepare_pythonjob_inputs(
     if code is None:
         command_info = command_info or {}
         code = get_or_create_code(computer=computer, **command_info)
-    # get the source code of the function
-    function_name = function_data["name"]
-    if function_data.get("is_pickle", False):
-        function_source_code = (
-            function_data["import_statements"] + "\n" + function_data["source_code_without_decorator"]
-        )
-    else:
-        function_source_code = f"from {function_data['module']} import {function_name}"
-    # serialize the kwargs into AiiDA Data
     function_inputs = function_inputs or {}
     function_inputs = serialize_to_aiida_nodes(function_inputs)
-    # transfer the args to kwargs
+    function_data["outputs"] = function_outputs or [{"name": "result"}]
     inputs = {
-        "process_label": process_label or "PythonJob<{}>".format(function_name),
-        "function_data": orm.Dict(
-            {
-                "source_code": function_source_code,
-                "name": function_name,
-                "outputs": function_outputs or [],
-            }
-        ),
+        "function_data": function_data,
         "code": code,
         "function_inputs": function_inputs,
         "upload_files": new_upload_files,
         "metadata": metadata or {},
         **kwargs,
     }
+    if process_label:
+        inputs["process_label"] = process_label
     return inputs
diff --git a/src/aiida_pythonjob/parsers/pythonjob.py b/src/aiida_pythonjob/parsers/pythonjob.py
index 26b686a..6222dfe 100644
--- a/src/aiida_pythonjob/parsers/pythonjob.py
+++ b/src/aiida_pythonjob/parsers/pythonjob.py
@@ -1,84 +1,122 @@
 """Parser for an `PythonJob` job."""
 
+import json
+
 from aiida.engine import ExitCode
 from aiida.parsers.parser import Parser
 
+# Map error_type from script.py to exit code label
+ERROR_TYPE_TO_EXIT_CODE = {
+    "IMPORT_CLOUDPICKLE_FAILED": "ERROR_IMPORT_CLOUDPICKLE_FAILED",
+    "UNPICKLE_INPUTS_FAILED": "ERROR_UNPICKLE_INPUTS_FAILED",
+    "UNPICKLE_FUNCTION_FAILED": "ERROR_UNPICKLE_FUNCTION_FAILED",
+    "FUNCTION_EXECUTION_FAILED": "ERROR_FUNCTION_EXECUTION_FAILED",
+    "PICKLE_RESULTS_FAILED": "ERROR_PICKLE_RESULTS_FAILED",
+}
+
 
 class PythonJobParser(Parser):
     """Parser for an `PythonJob` job."""
 
     def parse(self, **kwargs):
-        """Parse the contents of the output files stored in the `retrieved` output node.
-
-        The function_outputs could be a namespce, e.g.,
-        function_outputs=[
-            {"identifier": "namespace", "name": "add_multiply"},
-            {"name": "add_multiply.add"},
-            {"name": "add_multiply.multiply"},
-            {"name": "minus"},
-        ]
-        """
         import pickle
 
-        function_outputs = self.node.inputs.function_data.get_dict()["outputs"]
-        if len(function_outputs) == 0:
+        # Read function_outputs specification
+        if "outputs" in self.node.inputs.function_data:
+            function_outputs = self.node.inputs.function_data.outputs.get_list()
+        else:
             function_outputs = [{"name": "result"}]
         self.output_list = function_outputs
-        # first we remove nested outputs, e.g., "add_multiply.add"
+
+        # If nested outputs like "add_multiply.add", keep only top-level
         top_level_output_list = [output for output in self.output_list if "."
not in output["name"]] + + # 1) Read _error.json + error_data = {} + try: + with self.retrieved.base.repository.open("_error.json", "r") as ef: + error_data = json.load(ef) + except OSError: + # No _error.json file found + pass + except json.JSONDecodeError as exc: + self.logger.error(f"Error reading _error.json: {exc}") + return self.exit_codes.ERROR_INVALID_OUTPUT # or a different exit code + + # If error_data is non-empty, we have an error from the script + if error_data: + error_type = error_data.get("error_type", "UNKNOWN_ERROR") + exception_message = error_data.get("exception_message", "") + traceback_str = error_data.get("traceback", "") + + # Default to a generic code if we can't match a known error_type + exit_code_label = ERROR_TYPE_TO_EXIT_CODE.get(error_type, "ERROR_SCRIPT_FAILED") + + # Use `.format()` to inject the exception and traceback + return self.exit_codes[exit_code_label].format(exception=exception_message, traceback=traceback_str) + # 2) If we reach here, _error.json exists but is empty or doesn't exist at all -> no error recorded + # Proceed with parsing results.pickle try: with self.retrieved.base.repository.open("results.pickle", "rb") as handle: results = pickle.load(handle) + if isinstance(results, tuple): if len(top_level_output_list) != len(results): return self.exit_codes.ERROR_RESULT_OUTPUT_MISMATCH for i in range(len(top_level_output_list)): top_level_output_list[i]["value"] = self.serialize_output(results[i], top_level_output_list[i]) + elif isinstance(results, dict): - # pop the exit code if it exists + # pop the exit code if it exists inside the dictionary exit_code = results.pop("exit_code", None) if exit_code: + # If there's an exit_code, handle it (dict or int) if isinstance(exit_code, dict): exit_code = ExitCode(exit_code["status"], exit_code["message"]) elif isinstance(exit_code, int): exit_code = ExitCode(exit_code) if exit_code.status != 0: return exit_code + if len(top_level_output_list) == 1: - # if output name in results, use it + # If output name in results, use it if top_level_output_list[0]["name"] in results: top_level_output_list[0]["value"] = self.serialize_output( results.pop(top_level_output_list[0]["name"]), top_level_output_list[0], ) - # if there are any remaining results, raise an warning + # If there are any extra keys in `results`, log a warning if len(results) > 0: self.logger.warning( f"Found extra results that are not included in the output: {results.keys()}" ) - # otherwise, we assume the results is the output else: + # Otherwise assume the entire dict is the single output top_level_output_list[0]["value"] = self.serialize_output(results, top_level_output_list[0]) elif len(top_level_output_list) > 1: + # Match each top-level output by name for output in top_level_output_list: if output["name"] not in results: if output.get("required", True): return self.exit_codes.ERROR_MISSING_OUTPUT else: output["value"] = self.serialize_output(results.pop(output["name"]), output) - # if there are any remaining results, raise an warning + # Any remaining results are unaccounted for -> log a warning if len(results) > 0: self.logger.warning( f"Found extra results that are not included in the output: {results.keys()}" ) elif len(top_level_output_list) == 1: - # otherwise it returns a single value, we assume the results is the output + # Single top-level output, single result top_level_output_list[0]["value"] = self.serialize_output(results, top_level_output_list[0]) else: return self.exit_codes.ERROR_RESULT_OUTPUT_MISMATCH + + # Store the 
outputs for output in top_level_output_list: self.out(output["name"], output["value"]) + except OSError: return self.exit_codes.ERROR_READING_OUTPUT_FILE except ValueError as exception: @@ -86,7 +124,7 @@ def parse(self, **kwargs): return self.exit_codes.ERROR_INVALID_OUTPUT def find_output(self, name): - """Find the output with the given name.""" + """Find the output spec with the given name.""" for output in self.output_list: if output["name"] == name: return output @@ -109,6 +147,7 @@ def serialize_output(self, result, output): serialized_result[key] = general_serializer(value) return serialized_result else: - self.exit_codes.ERROR_INVALID_OUTPUT + self.logger.error(f"Expected a dict for namespace '{name}', got {type(result)}.") + return self.exit_codes.ERROR_INVALID_OUTPUT else: return general_serializer(result) diff --git a/src/aiida_pythonjob/utils.py b/src/aiida_pythonjob/utils.py index 495d5ba..34d295d 100644 --- a/src/aiida_pythonjob/utils.py +++ b/src/aiida_pythonjob/utils.py @@ -1,5 +1,4 @@ import inspect -import textwrap from typing import Any, Callable, Dict, List, Optional, Tuple, Union, _SpecialForm, get_type_hints from aiida.common.exceptions import NotExistent @@ -41,57 +40,41 @@ def inspect_function(func: Callable) -> Dict[str, Any]: """Serialize a function for storage or transmission.""" # we need save the source code explicitly, because in the case of jupyter notebook, # the source code is not saved in the pickle file + from aiida_pythonjob.data.pickled_data import PickledData + try: source_code = inspect.getsource(func) + # Split the source into lines for processing + source_code_lines = source_code.split("\n") + source_code = "\n".join(source_code_lines) except OSError: - raise ValueError("Failed to get the source code of the function.") - - # Split the source into lines for processing - source_code_lines = source_code.split("\n") - function_source_code = "\n".join(source_code_lines) - # Find the first line of the actual function definition - for i, line in enumerate(source_code_lines): - if line.strip().startswith("def "): - break - function_source_code_without_decorator = "\n".join(source_code_lines[i:]) - function_source_code_without_decorator = textwrap.dedent(function_source_code_without_decorator) - # we also need to include the necessary imports for the types used in the type hints. - try: - required_imports = get_required_imports(func) - except Exception as exception: - raise ValueError(f"Failed to get the required imports for the function: {exception}") - # Generate import statements - import_statements = "\n".join( - f"from {module} import {', '.join(types)}" for module, types in required_imports.items() - ) - return { - "name": func.__name__, - "source_code": function_source_code, - "source_code_without_decorator": function_source_code_without_decorator, - "import_statements": import_statements, - "is_pickle": True, - } - - -def build_function_data(func): - """Return the executor for this node.""" + source_code = "Failed to retrieve source code." + + return {"source_code": source_code, "mode": "use_pickled_function", "pickled_function": PickledData(value=func)} + + +def build_function_data(func: Callable) -> Dict[str, Any]: + """Inspect the function and return a dictionary with the function data.""" import types if isinstance(func, (types.FunctionType, types.BuiltinFunctionType, type)): # Check if callable is nested (contains dots in __qualname__ after the first segment) + function_data = {"name": func.__name__} if func.__module__ == "__main__" or "." 
in func.__qualname__.split(".", 1)[-1]: # Local or nested callable, so pickle the callable - executor = inspect_function(func) + function_data.update(inspect_function(func)) else: # Global callable (function/class), store its module and name for reference - executor = { - "module": func.__module__, - "name": func.__name__, - "is_pickle": False, - } + function_data.update( + { + "mode": "use_module_path", + "module_path": func.__module__, + "source_code": f"from {func.__module__} import {func.__name__}", + } + ) else: raise TypeError("Provided object is not a callable function or class.") - return executor + return function_data def get_or_create_code( diff --git a/tests/test_parser.py b/tests/test_parser.py index 512e002..b6dff45 100644 --- a/tests/test_parser.py +++ b/tests/test_parser.py @@ -1,46 +1,55 @@ +from __future__ import annotations + +import json import pathlib import tempfile import cloudpickle as pickle +import pytest from aiida import orm from aiida.cmdline.utils.common import get_workchain_report from aiida.common.links import LinkType -def create_retrieved_folder(result: dict, output_filename="results.pickle"): +def create_retrieved_folder(result: dict, error: dict | None = None, output_filename="results.pickle"): # Create a retrieved ``FolderData`` node with results with tempfile.TemporaryDirectory() as tmpdir: dirpath = pathlib.Path(tmpdir) with open((dirpath / output_filename), "wb") as handle: pickle.dump(result, handle) + error = error or {} + with open((dirpath / "_error.json"), "w") as handle: + json.dump(error, handle) folder_data = orm.FolderData(tree=dirpath.absolute()) return folder_data -def create_process_node(result: dict, function_data: dict, output_filename: str = "results.pickle"): +def create_process_node( + result: dict, function_data: dict, error: dict | None = None, output_filename: str = "results.pickle" +): node = orm.CalcJobNode() node.set_process_type("aiida.calculations:pythonjob.pythonjob") - function_data = orm.Dict(function_data) - retrieved = create_retrieved_folder(result, output_filename=output_filename) - node.base.links.add_incoming(function_data, link_type=LinkType.INPUT_CALC, link_label="function_data") + retrieved = create_retrieved_folder(result, error=error, output_filename=output_filename) + for key, value in function_data.items(): + node.base.links.add_incoming(value, link_type=LinkType.INPUT_CALC, link_label=f"function_data__{key}") + value.store() retrieved.base.links.add_incoming(node, link_type=LinkType.CREATE, link_label="retrieved") - function_data.store() node.store() retrieved.store() return node -def create_parser(result, function_data, output_filename="results.pickle"): +def create_parser(result, function_data, error: dict | None = None, output_filename: str = "results.pickle"): from aiida_pythonjob.parsers import PythonJobParser - node = create_process_node(result, function_data, output_filename=output_filename) + node = create_process_node(result, function_data, error=error, output_filename=output_filename) parser = PythonJobParser(node=node) return parser def test_tuple_result(fixture_localhost): result = (1, 2, 3) - function_data = {"outputs": [{"name": "a"}, {"name": "b"}, {"name": "c"}]} + function_data = {"outputs": orm.List([{"name": "a"}, {"name": "b"}, {"name": "c"}])} parser = create_parser(result, function_data) exit_code = parser.parse() assert exit_code is None @@ -49,7 +58,7 @@ def test_tuple_result(fixture_localhost): def test_tuple_result_mismatch(fixture_localhost): result = (1, 2) - function_data = 
{"outputs": [{"name": "a"}, {"name": "b"}, {"name": "c"}]} + function_data = {"outputs": orm.List([{"name": "a"}, {"name": "b"}, {"name": "c"}])} parser = create_parser(result, function_data) exit_code = parser.parse() assert exit_code == parser.exit_codes.ERROR_RESULT_OUTPUT_MISMATCH @@ -57,7 +66,7 @@ def test_tuple_result_mismatch(fixture_localhost): def test_dict_result(fixture_localhost): result = {"a": 1, "b": 2, "c": 3} - function_data = {"outputs": [{"name": "a"}, {"name": "b"}]} + function_data = {"outputs": orm.List([{"name": "a"}, {"name": "b"}])} parser = create_parser(result, function_data) exit_code = parser.parse() assert exit_code is None @@ -68,7 +77,7 @@ def test_dict_result(fixture_localhost): def test_dict_result_missing(fixture_localhost): result = {"a": 1, "b": 2} - function_data = {"outputs": [{"name": "a"}, {"name": "b"}, {"name": "c"}]} + function_data = {"outputs": orm.List([{"name": "a"}, {"name": "b"}, {"name": "c"}])} parser = create_parser(result, function_data) exit_code = parser.parse() assert exit_code == parser.exit_codes.ERROR_MISSING_OUTPUT @@ -76,7 +85,7 @@ def test_dict_result_missing(fixture_localhost): def test_dict_result_as_one_output(fixture_localhost): result = {"a": 1, "b": 2, "c": 3} - function_data = {"outputs": [{"name": "result"}]} + function_data = {"outputs": orm.List([{"name": "result"}])} parser = create_parser(result, function_data) exit_code = parser.parse() assert exit_code is None @@ -86,7 +95,7 @@ def test_dict_result_as_one_output(fixture_localhost): def test_dict_result_only_show_one_output(fixture_localhost): result = {"a": 1, "b": 2} - function_data = {"outputs": [{"name": "a"}]} + function_data = {"outputs": orm.List([{"name": "a"}])} parser = create_parser(result, function_data) parser.parse() assert len(parser.outputs) == 1 @@ -97,14 +106,14 @@ def test_dict_result_only_show_one_output(fixture_localhost): def test_exit_code(fixture_localhost): result = {"a": 1, "exit_code": {"status": 0, "message": ""}} - function_data = {"outputs": [{"name": "a"}]} + function_data = {"outputs": orm.List([{"name": "a"}])} parser = create_parser(result, function_data) exit_code = parser.parse() assert exit_code is None assert parser.outputs["a"] == 1 # result = {"exit_code": {"status": 1, "message": "error"}} - function_data = {"outputs": [{"name": "a"}, {"name": "b"}, {"name": "c"}]} + function_data = {"outputs": orm.List([{"name": "a"}, {"name": "b"}, {"name": "c"}])} parser = create_parser(result, function_data) exit_code = parser.parse() assert exit_code is not None @@ -114,7 +123,27 @@ def test_exit_code(fixture_localhost): def test_no_output_file(fixture_localhost): result = {"a": 1, "b": 2, "c": 3} - function_data = {"outputs": [{"name": "result"}]} + function_data = {"outputs": orm.List([{"name": "result"}])} parser = create_parser(result, function_data, output_filename="not_results.pickle") exit_code = parser.parse() assert exit_code == parser.exit_codes.ERROR_READING_OUTPUT_FILE + + +@pytest.mark.parametrize( + "error_type, status", + [ + ("IMPORT_CLOUDPICKLE_FAILED", 322), + ("UNPICKLE_INPUTS_FAILED", 323), + ("UNPICKLE_FUNCTION_FAILED", 324), + ("FUNCTION_EXECUTION_FAILED", 325), + ("PICKLE_RESULTS_FAILED", 326), + ], +) +def test_run_script_error(error_type, status): + error = {"error_type": error_type, "exception_message": "error", "traceback": "traceback"} + result = {"a": 1, "exit_code": {"status": 0, "message": ""}} + function_data = {"outputs": orm.List([{"name": "a"}])} + parser = create_parser(result, function_data, 
error=error)
+    exit_code = parser.parse()
+    assert exit_code is not None
+    assert exit_code.status == status
diff --git a/tests/test_pythonjob.py b/tests/test_pythonjob.py
index 130c535..9c0ea55 100644
--- a/tests/test_pythonjob.py
+++ b/tests/test_pythonjob.py
@@ -19,7 +19,7 @@ def add(x, y):
     with pytest.raises(ValueError, match="Only one of function or function_data should be provided"):
         prepare_pythonjob_inputs(
             function=add,
-            function_data={"module": "math", "name": "sqrt", "is_pickle": False},
+            function_data={"module_path": "math", "name": "sqrt", "is_pickle": False},
         )
 
 
@@ -53,7 +53,6 @@ def add(x, y):
             {"name": "diff"},
         ],
     )
-    inputs.pop("process_label")
     result, node = run_get_node(PythonJob, **inputs)
 
     assert result["sum"].value == 3
@@ -295,3 +294,18 @@ def add(x: array, y: array) -> array:
     result, node = run_get_node(PythonJob, inputs=inputs)
     assert node.exit_status == 410
     assert node.exit_message == "Some elements are negative"
+
+
+def test_local_function(fixture_localhost):
+    def multiply(x, y):
+        return x * y
+
+    def add(x, y):
+        return x + multiply(x, y)
+
+    inputs = prepare_pythonjob_inputs(
+        add,
+        function_inputs={"x": 2, "y": 3},
+    )
+    result, node = run_get_node(PythonJob, **inputs)
+    assert result["result"].value == 8
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 508d741..d232805 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -5,7 +5,12 @@ def test_build_function_data():
     from math import sqrt
 
     function_data = build_function_data(sqrt)
-    assert function_data == {"module": "math", "name": "sqrt", "is_pickle": False}
+    assert function_data == {
+        "name": "sqrt",
+        "mode": "use_module_path",
+        "module_path": "math",
+        "source_code": "from math import sqrt",
+    }
 
     # try:
     function_data = build_function_data(1)
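
Note for reviewers (outside the diff): the snippet below is a minimal usage sketch of the two pieces this patch wires together, prepare_pythonjob_inputs building the function_data namespace and PythonJob consuming it. It mirrors the pattern in tests/test_pythonjob.py; the aiida_pythonjob import path and a configured localhost computer/code are assumptions, not guaranteed by this patch.

    from aiida import load_profile
    from aiida.engine import run_get_node

    from aiida_pythonjob import PythonJob, prepare_pythonjob_inputs  # import path assumed

    load_profile()

    def add(x, y):
        # Defined in __main__, so build_function_data() pickles it
        # (mode "use_pickled_function") and the job uploads function.pkl
        # next to inputs.pickle on the remote computer.
        return x + y

    inputs = prepare_pythonjob_inputs(
        add,
        function_inputs={"x": 1, "y": 2},
    )
    result, node = run_get_node(PythonJob, **inputs)
    assert result["result"].value == 3  # default single output named "result"

An importable function (e.g. math.sqrt) would instead take the "use_module_path" branch of build_function_data and run via the generated "from math import sqrt" source line, with no function.pkl uploaded.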