Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FromGroupSubmissionController: enable multiple submissions per parent node #21

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 30 additions & 18 deletions aiida_submission_controller/from_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Optional

from aiida import orm
from pydantic import validator
from pydantic import Field, PrivateAttr, validator

from .base import BaseSubmissionController, validate_group_exists

Expand All @@ -15,36 +15,42 @@ class FromGroupSubmissionController(BaseSubmissionController): # pylint: disabl
and define the abstract methods.
"""

dynamic_extra: dict = Field(default_factory=dict)
"""A dictionary of dynamic extras to be added to the extras of the process."""
parent_group_label: str
"""Label of the parent group from which to construct the process inputs."""
filters: Optional[dict] = None
"""Filters applied to the query of the nodes in the parent group."""
order_by: Optional[dict] = None
"""Ordering applied to the query of the nodes in the parent group."""

_dynamic_extra_keys: tuple = PrivateAttr(default_factory=tuple)
_dynamic_extra_values: tuple = PrivateAttr(default_factory=tuple)

_validate_group_exists = validator("parent_group_label", allow_reuse=True)(validate_group_exists)

def __init__(self, **kwargs):
"""Initialize the instance."""
super().__init__(**kwargs)

if self.dynamic_extra:
self._dynamic_extra_keys, self._dynamic_extra_values = zip(*self.dynamic_extra.items())

@property
def parent_group(self):
"""Return the AiiDA ORM Group instance of the parent group."""
return orm.Group.objects.get(label=self.parent_group_label)

def get_parent_node_from_extras(self, extras_values):
"""Return the Node instance (in the parent group) from the (unique) extras identifying it."""
extras_projections = self.get_process_extra_projections()
assert len(extras_values) == len(extras_projections), f"The extras must be of length {len(extras_projections)}"
filters = dict(zip(extras_projections, extras_values))
def get_extra_unique_keys(self):
"""Return a tuple of the keys of the unique extras that will be used to uniquely identify your workchains."""
# `_parent_uuid` will be replaced by the `uuid` attribute in the queries
unique_extra_keys = self.unique_extra_keys or []
combined_extras = ["_parent_uuid"] + list(unique_extra_keys) + list(self._dynamic_extra_keys)
return tuple(combined_extras)

qbuild = orm.QueryBuilder()
qbuild.append(orm.Group, filters={"id": self.parent_group.pk}, tag="group")
qbuild.append(orm.Node, project="*", filters=filters, tag="process", with_group="group")
qbuild.limit(2)
results = qbuild.all(flat=True)
if len(results) != 1:
raise ValueError(
"I would have expected only 1 result for extras={extras}, I found {'>1' if len(qbuild) else '0'}"
)
return results[0]
def get_parent_node_from_extras(self, extras_values):
"""Return the Node instance (in the parent group) from the `uuid` identifying it."""
return orm.load_node(extras_values[0])

def get_all_extras_to_submit(self):
"""Return a *set* of the values of all extras uniquely identifying all simulations that you want to submit.
Expand All @@ -57,11 +63,15 @@ def get_all_extras_to_submit(self):
"""
extras_projections = self.get_process_extra_projections()

# Use only the unique extras (and the parent uuid) to identify the processes to be submitted
if self._dynamic_extra_keys:
extras_projections = extras_projections[: -len(self._dynamic_extra_keys)]

qbuild = orm.QueryBuilder()
qbuild.append(orm.Group, filters={"id": self.parent_group.pk}, tag="group")
qbuild.append(
orm.Node,
project=extras_projections,
project=["uuid"] + extras_projections[1:], # Replace `_parent_uuid` with `uuid`
filters=self.filters,
tag="process",
with_group="group",
Expand All @@ -76,10 +86,12 @@ def get_all_extras_to_submit(self):
# First, however, convert to a list of tuples otherwise
# the inner lists are not hashable
results = [tuple(_) for _ in results]
for res in results:
for i, res in enumerate(results):
assert all(
extra is not None for extra in res
), "There is at least one of the nodes in the parent group that does not define one of the required extras."
results[i] = (*res, *self._dynamic_extra_values) # Add the dynamic extras to the results

results_set = set(results)

assert len(results) == len(results_set), "There are duplicate extras in the parent group"
Expand Down
Loading