diff --git a/aiida_submission_controller/base.py b/aiida_submission_controller/base.py index 3a51bcb..c108147 100644 --- a/aiida_submission_controller/base.py +++ b/aiida_submission_controller/base.py @@ -6,7 +6,7 @@ from aiida import engine, orm from aiida.common import NotExistent -from pydantic import BaseModel, validator +from pydantic import BaseModel, field_validator from rich import print from rich.console import Console from rich.table import Table @@ -56,12 +56,16 @@ class BaseSubmissionController(BaseModel): unique_extra_keys: Optional[tuple] = None """Tuple of keys defined in the extras that uniquely define each process to be run.""" - _validate_group_exists = validator("group_label", allow_reuse=True)(validate_group_exists) + @field_validator('group_label') + @classmethod + def validate_group_exists(cls, value: str) -> str: + """Validator that makes sure the ``Group`` with the provided label exists.""" + return validate_group_exists(value) @property def group(self): """Return the AiiDA ORM Group instance that is managed by this class.""" - return orm.Group.objects.get(label=self.group_label) + return orm.Group.collection.get(label=self.group_label) def get_query(self, process_projections, only_active=False): """Return a QueryBuilder object to get all processes in the group associated to this. @@ -236,7 +240,7 @@ def submit_new_batch(self, dry_run=False, sort=False, verbose=False): else: CMDLINE_LOGGER.report(f"Submitted work chain <{wc_node}> for extras <{workchain_extras}>.") - wc_node.set_extra_many(get_extras_dict(self.get_extra_unique_keys(), workchain_extras)) + wc_node.base.extras.set_many(get_extras_dict(self.get_extra_unique_keys(), workchain_extras)) self.group.add_nodes([wc_node]) submitted[workchain_extras] = wc_node diff --git a/aiida_submission_controller/from_group.py b/aiida_submission_controller/from_group.py index 1eea6a2..368e0d3 100644 --- a/aiida_submission_controller/from_group.py +++ b/aiida_submission_controller/from_group.py @@ -3,7 +3,7 @@ from typing import Optional from aiida import orm -from pydantic import validator +from pydantic import field_validator from .base import BaseSubmissionController, validate_group_exists @@ -22,12 +22,16 @@ class FromGroupSubmissionController(BaseSubmissionController): # pylint: disabl order_by: Optional[dict] = None """Ordering applied to the query of the nodes in the parent group.""" - _validate_group_exists = validator("parent_group_label", allow_reuse=True)(validate_group_exists) + @field_validator('parent_group_label') + @classmethod + def validate_parent_group_exists(cls, value: str) -> str: + """Validator that makes sure the parent ``Group`` with the provided label exists.""" + return validate_group_exists(value) @property def parent_group(self): """Return the AiiDA ORM Group instance of the parent group.""" - return orm.Group.objects.get(label=self.parent_group_label) + return orm.Group.collection.get(label=self.parent_group_label) def get_parent_node_from_extras(self, extras_values): """Return the Node instance (in the parent group) from the (unique) extras identifying it.""" diff --git a/examples/add_in_batches.py b/examples/add_in_batches.py index 7b362c6..c9ce9e2 100644 --- a/examples/add_in_batches.py +++ b/examples/add_in_batches.py @@ -4,7 +4,7 @@ from aiida import orm from aiida.calculations.arithmetic.add import ArithmeticAddCalculation -from pydantic import validator +from pydantic import field_validator from aiida_submission_controller import BaseSubmissionController @@ -15,7 +15,8 @@ class AdditionTableSubmissionController(BaseSubmissionController): code_label: str """Label of the `code.arithmetic.add` `Code`.""" - @validator("code_label") + @field_validator("code_label") + @classmethod def _check_code_plugin(cls, value): plugin_type = orm.load_code(value).default_calc_job_plugin if plugin_type == "core.arithmetic.add": @@ -64,7 +65,7 @@ def main(): ## verdi code setup -L add --on-computer --computer=localhost -P core.arithmetic.add --remote-abs-path=/bin/bash -n # Create a controller - group, _ = orm.Group.objects.get_or_create(label="tests/addition_table") + group, _ = orm.Group.collection.get_or_create(label="tests/addition_table") controller = AdditionTableSubmissionController( code_label="add@localhost", diff --git a/pyproject.toml b/pyproject.toml index 0b85b55..ba8a44e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,8 +27,8 @@ classifiers = [ requires-python = ">=3.6" dependencies = [ - "aiida-core>=1.0", - "pydantic~=1.10.4", + "aiida-core>=2.0", + "pydantic~=2.7", "rich", ]