Skip to content

Commit

Permalink
TMP commit for transfer
Browse files Browse the repository at this point in the history
  • Loading branch information
mbercx committed Jul 16, 2024
1 parent f3f6dac commit 0c940e1
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 12 deletions.
12 changes: 8 additions & 4 deletions aiida_submission_controller/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from aiida import engine, orm
from aiida.common import NotExistent
from pydantic import BaseModel, validator
from pydantic import BaseModel, field_validator
from rich import print
from rich.console import Console
from rich.table import Table
Expand Down Expand Up @@ -56,12 +56,16 @@ class BaseSubmissionController(BaseModel):
unique_extra_keys: Optional[tuple] = None
"""Tuple of keys defined in the extras that uniquely define each process to be run."""

_validate_group_exists = validator("group_label", allow_reuse=True)(validate_group_exists)
@field_validator('group_label')
@classmethod
def validate_group_exists(cls, value: str) -> str:
"""Validator that makes sure the ``Group`` with the provided label exists."""
return validate_group_exists(value)

@property
def group(self):
"""Return the AiiDA ORM Group instance that is managed by this class."""
return orm.Group.objects.get(label=self.group_label)
return orm.Group.collection.get(label=self.group_label)

def get_query(self, process_projections, only_active=False):
"""Return a QueryBuilder object to get all processes in the group associated to this.
Expand Down Expand Up @@ -236,7 +240,7 @@ def submit_new_batch(self, dry_run=False, sort=False, verbose=False):
else:
CMDLINE_LOGGER.report(f"Submitted work chain <{wc_node}> for extras <{workchain_extras}>.")

wc_node.set_extra_many(get_extras_dict(self.get_extra_unique_keys(), workchain_extras))
wc_node.base.extras.set_many(get_extras_dict(self.get_extra_unique_keys(), workchain_extras))
self.group.add_nodes([wc_node])
submitted[workchain_extras] = wc_node

Expand Down
10 changes: 7 additions & 3 deletions aiida_submission_controller/from_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Optional

from aiida import orm
from pydantic import validator
from pydantic import field_validator

from .base import BaseSubmissionController, validate_group_exists

Expand All @@ -22,12 +22,16 @@ class FromGroupSubmissionController(BaseSubmissionController): # pylint: disabl
order_by: Optional[dict] = None
"""Ordering applied to the query of the nodes in the parent group."""

_validate_group_exists = validator("parent_group_label", allow_reuse=True)(validate_group_exists)
@field_validator('parent_group_label')
@classmethod
def validate_parent_group_exists(cls, value: str) -> str:
"""Validator that makes sure the parent ``Group`` with the provided label exists."""
return validate_group_exists(value)

@property
def parent_group(self):
"""Return the AiiDA ORM Group instance of the parent group."""
return orm.Group.objects.get(label=self.parent_group_label)
return orm.Group.collection.get(label=self.parent_group_label)

def get_parent_node_from_extras(self, extras_values):
"""Return the Node instance (in the parent group) from the (unique) extras identifying it."""
Expand Down
7 changes: 4 additions & 3 deletions examples/add_in_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from aiida import orm
from aiida.calculations.arithmetic.add import ArithmeticAddCalculation
from pydantic import validator
from pydantic import field_validator

from aiida_submission_controller import BaseSubmissionController

Expand All @@ -15,7 +15,8 @@ class AdditionTableSubmissionController(BaseSubmissionController):
code_label: str
"""Label of the `code.arithmetic.add` `Code`."""

@validator("code_label")
@field_validator("code_label")
@classmethod
def _check_code_plugin(cls, value):
plugin_type = orm.load_code(value).default_calc_job_plugin
if plugin_type == "core.arithmetic.add":
Expand Down Expand Up @@ -64,7 +65,7 @@ def main():
## verdi code setup -L add --on-computer --computer=localhost -P core.arithmetic.add --remote-abs-path=/bin/bash -n
# Create a controller

group, _ = orm.Group.objects.get_or_create(label="tests/addition_table")
group, _ = orm.Group.collection.get_or_create(label="tests/addition_table")

controller = AdditionTableSubmissionController(
code_label="add@localhost",
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ classifiers = [
requires-python = ">=3.6"

dependencies = [
"aiida-core>=1.0",
"pydantic~=1.10.4",
"aiida-core>=2.0",
"pydantic~=2.7",
"rich",
]

Expand Down

0 comments on commit 0c940e1

Please sign in to comment.