Accept UpfData to construct builder
unkcpz committed May 14, 2024
1 parent 49a9b17 commit 3770a2f
Showing 8 changed files with 194 additions and 48 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -74,3 +74,6 @@ filterwarnings = [
'ignore:Object of type .* not in session, .* operation along .* will not proceed:sqlalchemy.exc.SAWarning',
'ignore:The `Code` class is deprecated.*:aiida.common.warnings.AiidaDeprecationWarning',
]
markers = [
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
]
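
A usage note (not part of the commit): registering the `slow` marker lets the long-running convergence tests be deselected. A minimal sketch using pytest's programmatic entry point, equivalent to running `pytest -m "not slow"` from the shell:

# Minimal sketch: run the suite while deselecting tests marked as slow.
import pytest

# pytest.main accepts the same arguments as the command-line interface.
raise SystemExit(pytest.main(["-m", "not slow"]))
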
80 changes: 39 additions & 41 deletions src/aiida_sssp_workflow/protocol/convergence.yml
@@ -1,14 +1,14 @@
---
acwf:
name: acwf
description: The protocol where input parameters bring from AiiDA common workflow (ACWF).
balanced:
name: balanced
description: The balanced protocol from Gabriel

base: # base parameters are inherited by other processes
occupations: smearing
degauss: 0.02 # balanced protocol of qe -> gabriel
smearing: fd
conv_thr_per_atom: 1.0e-10
kpoints_distance: 0.15 # balanced protocol of qe -> gabriel
conv_thr_per_atom: 1.0e-8
kpoints_distance: 0.2 # balanced protocol of qe -> gabriel
mixing_beta: 0.4

cohesive_energy:
@@ -35,43 +35,41 @@ acwf:
scale_count: 7
scale_increment: 0.02

acwf:
name: acwf
description: The parameters of the EOS are exactly the same as those used in the Nat. Rev. Phys. 2024 paper.

base: # base parameters are inherited by other processes
occupations: smearing
degauss: 0.0045 # balanced protocol of qe -> gabriel
smearing: fd
conv_thr_per_atom: 1.0e-10
kpoints_distance: 0.06 # balanced protocol of qe -> gabriel
mixing_beta: 0.4

cohesive_energy:
atom_smearing: gaussian
vacuum_length: 12.0

phonon_frequencies:
qpoints_list:
- [0.5, 0.5, 0.5]
epsilon: false
tr2_ph: 1.0e-16

pressure:
scale_count: 7
scale_increment: 0.02

bands:
init_nbands_factor: 3.0
fermi_shift: 10.0
kpoints_distance_scf: 0.15
kpoints_distance_bands: 0.15

#moderate:
# name: moderate
# description: The protocol where input parameters bring from aiidaqe moderate protocol. Only for QE >= 6.8
#
# base: # base parameters is inherit by other process
# occupations: smearing
# degauss: 0.01
# smearing: cold
# conv_thr_per_atom: 1.0e-10
# kpoints_distance: 0.15
#
# cohesive_energy:
# atom_smearing: gaussian
# vacuum_length: 12.0
#
# phonon_frequencies:
# qpoints_list:
# - [0.5, 0.5, 0.5]
# epsilon: false
# tr2_ph: 1.0e-16
#
# pressure:
# scale_count: 7
# scale_increment: 0.02
# mixing_beta: 0.4
#
# bands:
# init_nbands_factor: 3.0
# fermi_shift: 10.0
# kpoints_distance_scf: 0.15
# kpoints_distance_bands: 0.15
#
# eos:
# scale_count: 7
# scale_increment: 0.02
# mixing_beta: 0.4
eos:
scale_count: 7
scale_increment: 0.02

test:
name: test-only
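
For orientation (illustrative, not part of the commit): the renamed `balanced` protocol and the new `acwf` protocol keep the same nested layout, so their base parameters can be inspected directly from the YAML. A minimal sketch with PyYAML, reading the file at the path shown in the diff header (in practice the workflow loads this file internally):

# Minimal sketch: compare the base parameters of the two protocols defined above.
import yaml

with open("src/aiida_sssp_workflow/protocol/convergence.yml") as handle:
    protocols = yaml.safe_load(handle)

# "balanced" uses the looser settings; "acwf" follows the tighter EOS/ACWF settings.
print(protocols["balanced"]["base"]["kpoints_distance"])  # 0.2
print(protocols["acwf"]["base"]["kpoints_distance"])      # 0.06
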
13 changes: 8 additions & 5 deletions src/aiida_sssp_workflow/workflows/convergence/_base.py
@@ -6,13 +6,14 @@
The detailed parameters for different properties are defined in the subclasses that inherit from this base class.
"""

from typing import Union
from pathlib import Path
from abc import ABCMeta, abstractmethod

from aiida import orm
from aiida.engine import append_
from aiida.plugins import DataFactory
from aiida.engine import ProcessBuilder
from aiida_pseudo.data.pseudo import UpfData

from aiida_sssp_workflow.utils import (
get_default_configuration,
@@ -24,8 +24,6 @@
from aiida_sssp_workflow.workflows import SelfCleanWorkChain
from aiida_sssp_workflow.workflows.convergence.report import ConvergenceReport

UpfData = DataFactory("pseudo.upf")


class abstract_attribute(object):
"""lazy variable check: https://stackoverflow.com/a/32536493"""
@@ -210,7 +209,7 @@ def prepare_evaluate_builder(self, ecutwfc: int, ecutrho: int) -> dict:
@classmethod
def get_builder(
cls,
pseudo: Path,
pseudo: Union[Path, UpfData],
protocol: str,
cutoff_list: list,
configuration: str,
@@ -219,7 +218,11 @@ def get_builder(
"""Generate builder for the generic convergence workflow"""
builder = super().get_builder()
builder.protocol = orm.Str(protocol)
builder.pseudo = UpfData.get_or_create(pseudo)

if isinstance(pseudo, Path):
builder.pseudo = UpfData.get_or_create(pseudo)
else:
builder.pseudo = pseudo

if ret := is_valid_cutoff_list(cutoff_list):
raise ValueError(ret)
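
A usage sketch (not part of the commit) of the updated signature: `get_builder` now accepts either a filesystem `Path` or an existing `aiida_pseudo` `UpfData` node. The entry point, cutoff list, and configuration below are taken from the tests in this commit; the code label and pseudopotential path are assumptions for illustration:

# Minimal sketch: construct the convergence builder from a Path or from a UpfData node.
from pathlib import Path

from aiida import load_profile, orm
from aiida.plugins import WorkflowFactory
from aiida_pseudo.data.pseudo import UpfData

load_profile()

ConvergenceEOS = WorkflowFactory("sssp_workflow.convergence.eos")
code = orm.load_code("pw-docker@localhost")  # assumed code label

# Option 1: pass a Path; get_builder wraps it with UpfData.get_or_create.
builder = ConvergenceEOS.get_builder(
    pseudo=Path("/path/to/Al.upf"),  # assumed path
    protocol="test",
    cutoff_list=[(20, 80), (30, 120)],
    configuration="DC",
    code=code,
    clean_workdir=True,
)

# Option 2: pass an existing UpfData node directly (what this commit enables).
pseudo_node = UpfData.get_or_create(Path("/path/to/Al.upf"))
builder = ConvergenceEOS.get_builder(
    pseudo=pseudo_node,
    protocol="test",
    cutoff_list=[(20, 80), (30, 120)],
    configuration="DC",
    code=code,
    clean_workdir=True,
)
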
4 changes: 3 additions & 1 deletion src/aiida_sssp_workflow/workflows/convergence/eos.py
@@ -4,9 +4,11 @@
"""

from pathlib import Path
from typing import Union

from aiida import orm
from aiida.engine import ProcessBuilder
from aiida_pseudo.data.pseudo import UpfData

from aiida_sssp_workflow.utils import get_default_mpi_options
from aiida_sssp_workflow.workflows.convergence._base import _BaseConvergenceWorkChain
@@ -45,7 +47,7 @@ def define(cls, spec):
@classmethod
def get_builder(
cls,
pseudo: Path,
pseudo: Union[Path, UpfData],
protocol: str,
cutoff_list: list,
configuration: str,
75 changes: 74 additions & 1 deletion tests/conftest.py
@@ -7,6 +7,7 @@

from aiida import orm
from aiida.orm.utils.managers import NodeLinksManager
from aiida.engine import ProcessBuilder

pytest_plugins = ["aiida.manage.tests.pytest_fixtures"]

@@ -62,12 +63,69 @@ def _pseudo_path(element="Al"):
return _pseudo_path


def _serialize_data(data):
from aiida.orm import (
AbstractCode,
BaseType,
Data,
Dict,
KpointsData,
List,
RemoteData,
SinglefileData,
)
from aiida.plugins import DataFactory

StructureData = DataFactory("core.structure")
UpfData = DataFactory("pseudo.upf")

if isinstance(data, dict):
return {key: _serialize_data(value) for key, value in data.items()}

if isinstance(data, BaseType):
return data.value

if isinstance(data, AbstractCode):
return data.full_label

if isinstance(data, Dict):
return data.get_dict()

if isinstance(data, List):
return data.get_list()

if isinstance(data, StructureData):
return data.get_formula()

if isinstance(data, UpfData):
return f"{data.element}<md5={data.md5}>"

if isinstance(data, RemoteData):
# For `RemoteData` we compute the hash of the repository. The value returned by `Node._get_hash` is not
# useful since it includes the hash of the absolute filepath and the computer UUID which vary between tests
return data.base.repository.hash()

if isinstance(data, KpointsData):
try:
return data.get_kpoints()
except AttributeError:
return data.get_kpoints_mesh()

if isinstance(data, SinglefileData):
return data.get_content()

if isinstance(data, Data):
return data.base.caching._get_hash()

return data


@pytest.fixture
def serialize_inputs():
"""Serialize the given process inputs into a dictionary with nodes turned into their value representation.
(Borrowed from aiida-quantumespresso/tests/conftest.py::serialize_builder)
:param inputs: the process inputs of type NodeLinksManager to serialize
:return: dictionary
"""

@@ -136,3 +194,18 @@ def _serialize_inputs(inputs: NodeLinksManager):
return _serialize_data(_inputs)

return _serialize_inputs


@pytest.fixture
def serialize_builder():
"""Serialize the builder into a dictionary with nodes turned into their value representation.
(Borrowed from aiida-quantumespresso/tests/conftest.py::serialize_builder)
:param builder: the process builder to serialize
:return: dictionary
"""

def _serialize_builder(builder: ProcessBuilder):
return _serialize_data(builder._inputs(prune=True))

return _serialize_builder
31 changes: 31 additions & 0 deletions tests/workflows/convergence/test_base.py
@@ -9,6 +9,7 @@
UpfData = DataFactory("pseudo.upf")


@pytest.mark.slow
@pytest.mark.parametrize(
"entry_point",
[
@@ -62,5 +63,35 @@ def test_run_default(
data_regression.check(serialize_inputs(ref_evaluate_node.inputs))


@pytest.mark.parametrize(
"entry_point,clean_workdir",
[
("sssp_workflow.convergence.eos", True),
("sssp_workflow.convergence.eos", False),
],
)
def test_builder_pseudo_as_upfdata(
entry_point,
clean_workdir,
pseudo_path,
code_generator,
serialize_builder,
data_regression,
):
_ConvergenceWorkChain = WorkflowFactory(entry_point)
pseudo = UpfData.get_or_create(pseudo_path("Al"))

builder: ProcessBuilder = _ConvergenceWorkChain.get_builder(
pseudo=pseudo,
protocol="test",
cutoff_list=[(20, 80), (30, 120)],
configuration="DC",
code=code_generator("pw"),
clean_workdir=clean_workdir,
)

data_regression.check(serialize_builder(builder))


# TODO: test not clean workdir
# TODO: test validator of _base convergence workchain working as expected
@@ -0,0 +1,18 @@
clean_workdir: false
code: pw-docker@localhost
configuration: DC
cutoff_list:
- - 20
- 80
- - 30
- 120
metadata:
call_link_label: caching
mpi_options:
max_wallclock_seconds: 1800
resources:
num_machines: 1
withmpi: false
parallelization: {}
protocol: test
pseudo: Al<md5=a2ca6568aad2214016a12794e7e55b1e>
@@ -0,0 +1,18 @@
clean_workdir: true
code: pw-docker@localhost
configuration: DC
cutoff_list:
- - 20
- 80
- - 30
- 120
metadata:
call_link_label: caching
mpi_options:
max_wallclock_seconds: 1800
resources:
num_machines: 1
withmpi: false
parallelization: {}
protocol: test
pseudo: Al<md5=a2ca6568aad2214016a12794e7e55b1e>
