Skip to content

Commit

Permalink
Using a pickled function object
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Dec 14, 2024
1 parent 1cfe56f commit fa630d3
Show file tree
Hide file tree
Showing 9 changed files with 400 additions and 170 deletions.
202 changes: 130 additions & 72 deletions src/aiida_pythonjob/calculations/pythonjob.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Calcjob to run a Python function on a remote computer."""
"""Calcjob to run a Python function on a remote computer, either via raw source code or a pickled function."""

from __future__ import annotations

Expand All @@ -11,7 +11,6 @@
from aiida.engine import CalcJob, CalcJobProcessSpec
from aiida.orm import (
Data,
Dict,
FolderData,
List,
RemoteData,
Expand All @@ -24,7 +23,12 @@


class PythonJob(CalcJob):
"""Calcjob to run a Python function on a remote computer."""
"""Calcjob to run a Python function on a remote computer.
Supports two modes:
1) Loading a pickled function object (function_data.pickled_function).
2) Embedding raw source code for the function (function_data.source_code).
"""

_internal_retrieve_list = []
_retrieve_singlefile_list = []
Expand All @@ -36,16 +40,16 @@ class PythonJob(CalcJob):

@classmethod
def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override]
"""Define the process specification, including its inputs, outputs and known exit codes.
:param spec: the calculation job process spec to define.
"""
"""Define the process specification, including its inputs, outputs and known exit codes."""
super().define(spec)
spec.input("function_data", valid_type=Dict, serializer=to_aiida_type, required=False)
spec.input_namespace("function_data")
spec.input("function_data.name", valid_type=Str, serializer=to_aiida_type)
spec.input("function_data.source_code", valid_type=Str, serializer=to_aiida_type, required=False)
spec.input("function_data.outputs", valid_type=List, serializer=to_aiida_type, required=False)
spec.input("function_data.pickled_function", valid_type=Data, required=False)
spec.input("function_data.mode", valid_type=Str, serializer=to_aiida_type, required=False)
spec.input("process_label", valid_type=Str, serializer=to_aiida_type, required=False)
spec.input_namespace(
"function_inputs", valid_type=Data, required=False
) # , serializer=serialize_to_aiida_nodes)
spec.input_namespace("function_inputs", valid_type=Data, required=False)
spec.input(
"parent_folder",
valid_type=(RemoteData, FolderData, SinglefileData),
Expand All @@ -57,8 +61,8 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override]
valid_type=Str,
required=False,
serializer=to_aiida_type,
help="""Default name of the subfolder that you want to create in the working directory,
in which you want to place the files taken from parent_folder""",
help="""Default name of the subfolder to create in the working directory
where the files from parent_folder are placed.""",
)
spec.input(
"parent_output_folder",
Expand Down Expand Up @@ -86,7 +90,7 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override]
default=None,
required=False,
serializer=to_aiida_type,
help="The names of the files to retrieve",
help="Additional filenames to retrieve from the remote work directory",
)
spec.outputs.dynamic = True
# set default options (optional)
Expand All @@ -97,7 +101,6 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override]
"num_machines": 1,
"num_mpiprocs_per_machine": 1,
}
# start exit codes - marker for docs
spec.exit_code(
310,
"ERROR_READING_OUTPUT_FILE",
Expand All @@ -116,86 +119,131 @@ def define(cls, spec: CalcJobProcessSpec) -> None: # type: ignore[override]
invalidates_cache=True,
message="The number of results does not match the number of outputs.",
)
spec.exit_code(
322,
"ERROR_IMPORT_CLOUDPICKLE_FAILED",
invalidates_cache=True,
message="Importing cloudpickle failed.\n{exception}\n{traceback}",
)
spec.exit_code(
323,
"ERROR_UNPICKLE_INPUTS_FAILED",
invalidates_cache=True,
message="Failed to unpickle inputs.\n{exception}\n{traceback}",
)
spec.exit_code(
324,
"ERROR_UNPICKLE_FUNCTION_FAILED",
invalidates_cache=True,
message="Failed to unpickle user function.\n{exception}\n{traceback}",
)
spec.exit_code(
325,
"ERROR_FUNCTION_EXECUTION_FAILED",
invalidates_cache=True,
message="Function execution failed.\n{exception}\n{traceback}",
)
spec.exit_code(
326,
"ERROR_PICKLE_RESULTS_FAILED",
invalidates_cache=True,
message="Failed to pickle results.\n{exception}\n{traceback}",
)
spec.exit_code(
327,
"ERROR_SCRIPT_FAILED",
invalidates_cache=True,
message="The script failed for an unknown reason.\n{exception}\n{traceback}",
)

def _build_process_label(self) -> str:
"""Use the function name as the process label.
def get_function_name(self) -> str:
"""Return the name of the function to run."""
if "name" in self.inputs.function_data:
name = self.inputs.function_data.name.value
else:
try:
name = self.inputs.function_data.pickled_function.value.__name__
except AttributeError:
# If a user doesn't specify name, fallback to something generic
name = "anonymous_function"
return name

:returns: The process label to use for ``ProcessNode`` instances.
"""
def _build_process_label(self) -> str:
"""Use the function name or an explicit label as the process label."""
if "process_label" in self.inputs:
return self.inputs.process_label.value
else:
data = self.inputs.function_data.get_dict()
return f"PythonJob<{data['name']}>"
name = self.get_function_name()
return f"PythonJob<{name}>"

def on_create(self) -> None:
"""Called when a Process is created."""

super().on_create()
self.node.label = self._build_process_label()

def prepare_for_submission(self, folder: Folder) -> CalcInfo:
"""Prepare the calculation for submission.
1) Write the python script to the folder.
1) Write the python script to the folder, depending on the mode (source vs. pickled function).
2) Write the inputs to a pickle file and save it to the folder.
:param folder: A temporary folder on the local file system.
:returns: A :class:`aiida.common.datastructures.CalcInfo` instance.
"""
import cloudpickle as pickle

from aiida_pythonjob.calculations.utils import generate_script_py

dirpath = pathlib.Path(folder._abspath)
inputs: dict[str, t.Any]

# Prepare the dictionary of input arguments for the function
inputs: dict[str, t.Any]
if self.inputs.function_inputs:
inputs = dict(self.inputs.function_inputs)
else:
inputs = {}

# Prepare the final subfolder name for the parent folder
if "parent_folder_name" in self.inputs:
parent_folder_name = self.inputs.parent_folder_name.value
else:
parent_folder_name = self._DEFAULT_PARENT_FOLDER_NAME
function_data = self.inputs.function_data.get_dict()
# create python script to run the function
script = f"""
import pickle
# define the function
{function_data["source_code"]}
# load the inputs from the pickle file
with open('inputs.pickle', 'rb') as handle:
inputs = pickle.load(handle)
# run the function
result = {function_data["name"]}(**inputs)
# save the result as a pickle file
with open('results.pickle', 'wb') as handle:
pickle.dump(result, handle)
"""
# write the script to the folder

function_data = self.inputs.function_data

# Build the Python script
source_code = function_data.get("source_code")
if "pickled_function" in self.inputs.function_data:
pickled_function = self.inputs.function_data.pickled_function.get_serialized_value()
else:
pickled_function = None
# Generate script.py content
function_name = self.get_function_name() # or some user-defined name
script_content = generate_script_py(
pickled_function=pickled_function,
source_code=source_code.value if source_code else None,
function_name=function_name,
)

# Write the script to the working folder
with folder.open(self.options.input_filename, "w", encoding="utf8") as handle:
handle.write(script)
# symlink = settings.pop('PARENT_FOLDER_SYMLINK', False)
symlink = True
handle.write(script_content)

# Symlink or copy approach for the parent folder
symlink = True
remote_copy_list = []
local_copy_list = []
remote_symlink_list = []
remote_list = remote_symlink_list if symlink else remote_copy_list

source = self.inputs.get("parent_folder", None)

if source is not None:
if isinstance(source, RemoteData):
dirpath = pathlib.Path(source.get_remote_path())
# Possibly append parent_output_folder path
dirpath_remote = pathlib.Path(source.get_remote_path())
if self.inputs.parent_output_folder is not None:
dirpath = pathlib.Path(source.get_remote_path()) / self.inputs.parent_output_folder.value
dirpath_remote /= self.inputs.parent_output_folder.value
remote_list.append(
(
source.computer.uuid,
str(dirpath),
str(dirpath_remote),
parent_folder_name,
)
)
Expand All @@ -204,46 +252,56 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo:
local_copy_list.append((source.uuid, dirname, parent_folder_name))
elif isinstance(source, SinglefileData):
local_copy_list.append((source.uuid, source.filename, source.filename))

# Upload additional files
if "upload_files" in self.inputs:
upload_files = self.inputs.upload_files
for key, source in upload_files.items():
for key, src in upload_files.items():
# replace "_dot_" with "." in the key
new_key = key.replace("_dot_", ".")
if isinstance(source, FolderData):
local_copy_list.append((source.uuid, "", new_key))
elif isinstance(source, SinglefileData):
local_copy_list.append((source.uuid, source.filename, source.filename))
if isinstance(src, FolderData):
local_copy_list.append((src.uuid, "", new_key))
elif isinstance(src, SinglefileData):
local_copy_list.append((src.uuid, src.filename, src.filename))
else:
raise ValueError(
f"""Input folder/file: {source} is not supported.
Only AiiDA SinglefileData and FolderData are allowed."""
f"Input file/folder '{key}' of type {type(src)} is not supported. "
"Only AiiDA SinglefileData and FolderData are allowed."
)

# Copy remote data if any
if "copy_files" in self.inputs:
copy_files = self.inputs.copy_files
for key, source in copy_files.items():
# replace "_dot_" with "." in the key
for key, src in copy_files.items():
new_key = key.replace("_dot_", ".")
dirpath = pathlib.Path(source.get_remote_path())
remote_list.append((source.computer.uuid, str(dirpath), new_key))
# create pickle file for the inputs
dirpath_remote = pathlib.Path(src.get_remote_path())
remote_list.append((src.computer.uuid, str(dirpath_remote), new_key))

# Create a pickle file for the user input values
input_values = {}
for key, value in inputs.items():
if isinstance(value, Data) and hasattr(value, "value"):
# get the value of the pickled data
input_values[key] = value.value
# TODO: should check this recursively
elif isinstance(value, (AttributeDict, dict)):
# if the value is an AttributeDict, use recursively
# Convert an AttributeDict/dict with .value items
input_values[key] = {k: v.value for k, v in value.items()}
else:
raise ValueError(
f"Input data {value} is not supported. Only AiiDA data Node with a value attribute is allowed. "
f"Input data {value} is not supported. Only AiiDA Data nodes with a '.value' or "
"AttributeDict/dict-of-Data are allowed."
)
# save the value as a pickle file, the path is absolute

filename = "inputs.pickle"
dirpath = pathlib.Path(folder._abspath)
with folder.open(filename, "wb") as handle:
pickle.dump(input_values, handle)

# If using a pickled function, we also need to upload the function pickle
if pickled_function:
# write the serialized function bytes to "function.pkl" in the submission folder
function_pkl_fname = "function.pkl"
with folder.open(function_pkl_fname, "wb") as handle:
handle.write(pickled_function)

# create a singlefiledata object for the pickled data
file_data = SinglefileData(file=f"{dirpath}/{filename}")
file_data.store()
Expand All @@ -259,7 +317,7 @@ def prepare_for_submission(self, folder: Folder) -> CalcInfo:
calcinfo.local_copy_list = local_copy_list
calcinfo.remote_copy_list = remote_copy_list
calcinfo.remote_symlink_list = remote_symlink_list
calcinfo.retrieve_list = ["results.pickle", self.options.output_filename]
calcinfo.retrieve_list = ["results.pickle", self.options.output_filename, "_error.json"]
if self.inputs.additional_retrieve_list is not None:
calcinfo.retrieve_list += self.inputs.additional_retrieve_list.get_list()
calcinfo.retrieve_list += self._internal_retrieve_list
Expand Down
Loading

0 comments on commit fa630d3

Please sign in to comment.