Skip to content

Commit

Permalink
Modification of the AiiDAEngine.
Browse files Browse the repository at this point in the history
  • Loading branch information
mikibonacci committed Dec 4, 2024
1 parent 07edbf9 commit c906654
Show file tree
Hide file tree
Showing 8 changed files with 99 additions and 98 deletions.
2 changes: 1 addition & 1 deletion notebooks/h20_wtree.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@
"metadata": {},
"outputs": [],
"source": [
"from ase import Atoms\n",
"from ase_koopmans import Atoms\n",
"import copy\n",
"\n",
"H2O = Atoms(atoms,\n",
Expand Down
2 changes: 1 addition & 1 deletion notebooks/koop_insp.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@
"source": [
"## Trying aiida-shell\n",
"\n",
"1. from ase calculator to aiida-shell inputs\n",
"1. from ase_koopmans calculator to aiida-shell inputs\n",
"\n",
"The first thing to do is to write the input file."
]
Expand Down
2 changes: 1 addition & 1 deletion notebooks/multiple_pw_conv.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@
"metadata": {},
"outputs": [],
"source": [
"from ase import Atoms\n",
"from ase_koopmans import Atoms\n",
"import copy\n",
"\n",
"H2O = Atoms(atoms,\n",
Expand Down
2 changes: 1 addition & 1 deletion quick_start.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@
"metadata": {},
"outputs": [],
"source": [
"from ase.io import read\n",
"from ase_koopmans.io import read\n",
"\n",
"ozone = read(\"/home/jovyan/work/koopmans_calcs/tutorial_1/ozon.xsf\")\n",
"ozone.cell = [[14.1738, 0.0, 0.0],\n",
Expand Down
167 changes: 80 additions & 87 deletions src/aiida_koopmans/engine/aiida.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from koopmans.step import Step
from koopmans.calculators import Calc
from koopmans.pseudopotentials import read_pseudo_file
from koopmans.status import Status

from aiida.engine import run_get_node, submit

Expand All @@ -23,7 +24,6 @@ class AiiDAEngine(Engine):
step_data = {calc.directory: {'workchain': workchain, 'remote_folder': remote_folder}}
and any other info we need for AiiDA.
"""

def __init__(self, *args, **kwargs):
self.blocking = kwargs.pop('blocking', True)
self.step_data = { # TODO: change to a better name
Expand All @@ -43,77 +43,97 @@ def __init__(self, *args, **kwargs):


super().__init__(*args, **kwargs)

def run(self, step: Step):

self.get_status(step)
if step.prefix in ['wannier90_preproc', 'pw2wannier90']:
self.set_status(step, Status.COMPLETED)
#self._step_completed_message(step)
return

self.step_data['steps'][step.uid] = {} # maybe not needed
builder = get_builder_from_ase(calculator=step, step_data=self.step_data) # ASE to AiiDA conversion. put some error message if the conversion fails
running = submit(builder)
# running = aiidawrapperwchain.submit(builder) # in the non-blocking case.
self.step_data['steps'][step.uid] = {'workchain': running.pk, } #'remote_folder': running.outputs.remote_folder}

self.set_status(step, Status.RUNNING)

return


def _run_steps(self, steps: tuple[Step, ...]) -> None:
def load_step_data(self):
try:
with open('step_data.pkl', 'rb') as f:
self.step_data = pickle.load(f) # this will overwrite the step_data[configuration], ie. if we change codes or res we will not see it if the file already exists.
except:
# this will overwrite the step_data[configuration],
# i.e. if we change codes or res we will not see it if
# the file already exists.
self.step_data = pickle.load(f)
except FileNotFoundError:
pass

self.from_scratch = False
for step in steps:
# self._step_running_message(step)
if step.directory in self.step_data['steps']:
continue
elif step.prefix in ['wannier90_preproc', 'pw2wannier90']:
print(f'skipping {step.prefix} step')
continue
else:
self.from_scratch = True

#step.run()
self.step_data['steps'][step.directory] = {}
builder = get_builder_from_ase(calculator=step, step_data=self.step_data) # ASE to AiiDA conversion. put some error message if the conversion fails
running = submit(builder)
# running = aiidawrapperwchain.submit(builder) # in the non-blocking case.
self.step_data['steps'][step.directory] = {'workchain': running.pk, } #'remote_folder': running.outputs.remote_folder}

#if self.from_scratch:
def dump_step_data(self):
with open('step_data.pkl', 'wb') as f:
pickle.dump(self.step_data, f)
pickle.dump(self.step_data, f)

def get_status(self, step: Step) -> Status:
return self.get_status_by_uid(step.uid)

if not self.blocking and self.from_scratch:
raise CalculationSubmitted("Calculation submitted to AiiDA, non blocking")
elif self.blocking:
for step in self.step_data['steps'].values():
while not orm.load_node(step['workchain']).is_finished:
time.sleep(5)

for step in steps:

def get_status_by_uid(self, uid: str) -> Status:
self.load_step_data()
if uid not in self.step_data['steps']:
self.step_data['steps'][uid] = {'status': Status.NOT_STARTED}
return self.step_data['steps'][uid]['status']

def set_status(self, step: Step, status: Status):
self.set_status_by_uid(step.uid, status)

def set_status_by_uid(self, uid: str, status: Status):
self.step_data['steps'][uid]['status'] = status
self.dump_step_data()

def update_statuses(self) -> None:
time.sleep(5)
for uid in self.step_data['steps']:
# convert from AiiDA to ASE results and populate ASE calculator
# TOBE put in a separate function
if step.prefix in ['wannier90_preproc', 'pw2wannier90']:
if not self.get_status_by_uid(uid) == Status.RUNNING:
continue
workchain = orm.load_node(self.step_data['steps'][step.directory]['workchain'])
if "remote_folder" in workchain.outputs:
self.step_data['steps'][step.directory]['remote_folder'] = workchain.outputs.remote_folder.pk
output = None
if step.ext_out == ".wout":
output = read_output_file(step, workchain.outputs.wannier90.retrieved)
elif step.ext_out in [".pwo",".kho"]:
output = read_output_file(step, workchain.outputs.retrieved)
if hasattr(output.calc, 'kpts'):
step.kpts = output.calc.kpts
else:
output = read_output_file(step, workchain.outputs.retrieved)
if step.ext_out in [".pwo",".wout",".kso",".kho"]:
step.calc = output.calc
step.results = output.calc.results
step.generate_band_structure(nelec=int(workchain.outputs.output_parameters.get_dict()['number_of_electrons']))

workchain = orm.load_node(self.step_data['steps'][uid]['workchain'])
if workchain.is_finished_ok:
self.set_status_by_uid(uid, Status.COMPLETED)

elif workchain.is_finished or workchain.is_excepted or workchain.is_killed:
self.set_status_by_uid(uid, Status.FAILED)

return

def load_results(self, step: Step) -> None:

self.load_step_data()
workchain = orm.load_node(self.step_data['steps'][step.uid]['workchain'])
if "remote_folder" in workchain.outputs:
self.step_data['steps'][step.uid]['remote_folder'] = workchain.outputs.remote_folder.pk
output = None
if step.ext_out == ".wout":
output = read_output_file(step, workchain.outputs.wannier90.retrieved)
elif step.ext_out in [".pwo",".kho"]:
output = read_output_file(step, workchain.outputs.retrieved)
if hasattr(output.calc, 'kpts'):
step.kpts = output.calc.kpts
else:
output = read_output_file(step, workchain.outputs.retrieved)
if step.ext_out in [".pwo",".wout",".kso",".kho"]:
step.calc = output.calc
step.results = output.calc.results
step.generate_band_structure() #nelec=int(workchain.outputs.output_parameters.get_dict()['number_of_electrons']))

self._step_completed_message(step)

# If we reached here, all future steps should be performed from scratch
self.from_scratch = True

# dump again to have update the information
with open('step_data.pkl', 'wb') as f:
pickle.dump(self.step_data, f)

self.dump_step_data()

return


def load_old_calculator(self, calc: Calc):
raise NotImplementedError # load_old_calculator(calc)

Expand All @@ -135,30 +155,3 @@ def get_pseudo_data(self, workflow):
pseudo_data[pseudo[0].attributes.element] = read_pseudo_file(temp_file)

return pseudo_data


def load_old_calculator(calc):
# This is a separate function so that it can be imported by other engines
loaded_calc = calc.__class__.fromfile(calc.directory / calc.prefix)

if loaded_calc.is_complete():
# If it is complete, load the results
calc.results = loaded_calc.results

# Check the convergence of the calculation
calc.check_convergence()

# Load k-points if relevant
if hasattr(loaded_calc, 'kpts'):
calc.kpts = loaded_calc.kpts

if isinstance(calc, ReturnsBandStructure):
calc.generate_band_structure()

if isinstance(calc, ProjwfcCalculator):
calc.generate_dos()

if isinstance(calc, PhCalculator):
calc.read_dynG()

return loaded_calc
4 changes: 2 additions & 2 deletions src/aiida_koopmans/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from aiida.orm import Code, Computer
from aiida_quantumespresso.calculations.pw import PwCalculation
from aiida_wannier90.calculations.wannier90 import Wannier90Calculation
from ase import io
from ase.io.espresso import kch_keys, kcp_keys, kcs_keys, pw_keys, w2kcw_keys
from ase_koopmans import io
from ase_koopmans.io.espresso import kch_keys, kcp_keys, kcs_keys, pw_keys, w2kcw_keys

from aiida_koopmans.calculations.kcw import KcwCalculation
from aiida_koopmans.data.utils import generate_singlefiledata, generate_alpha_singlefiledata, produce_wannier90_files
Expand Down
2 changes: 1 addition & 1 deletion src/aiida_koopmans/parsers/kcw.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pathlib import Path
import pathlib
import tempfile
from ase import io
from ase_koopmans import io

from aiida.orm import Dict

Expand Down
16 changes: 12 additions & 4 deletions src/aiida_koopmans/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@
from aiida.orm import Code, Computer
from aiida_quantumespresso.calculations.pw import PwCalculation
from aiida_wannier90.calculations.wannier90 import Wannier90Calculation
from ase import io
from ase.io.espresso import kch_keys, kcp_keys, kcs_keys, pw_keys, w2kcw_keys

from ase import Atoms
from ase_koopmans import Atoms as AtomsKoopmans
from ase_koopmans import io
from ase_koopmans.io.espresso import kch_keys, kcp_keys, kcs_keys, pw_keys, w2kcw_keys

from aiida_koopmans.calculations.kcw import KcwCalculation
from aiida_koopmans.data.utils import generate_singlefiledata, generate_alpha_singlefiledata, produce_wannier90_files
Expand Down Expand Up @@ -54,7 +57,12 @@ def get_PwBaseWorkChain_from_ase(pw_calculator, step_data=None):
"""
aiida_inputs = step_data['configuration']
calc_params = pw_calculator._parameters
structure = orm.StructureData(ase=pw_calculator.atoms) # TODO: only one sdata, stored in the step_data dict. but some cases have output structure diff from input.

if isinstance(pw_calculator.atoms, AtomsKoopmans):
ase_atoms = Atoms.fromdict(pw_calculator.atoms.todict())

# WE NEED TO USE THE INPUT STRUCTURE OF SCF, WHEN WE DO NSCF
structure = orm.StructureData(ase=ase_atoms) # TODO: only one sdata, stored in the step_data dict. but some cases have output structure diff from input.

pw_overrides = {
"CONTROL": {},
Expand Down Expand Up @@ -93,7 +101,7 @@ def get_PwBaseWorkChain_from_ase(pw_calculator, step_data=None):
# here we need explicit kpoints
builder.kpoints.set_kpoints(calc_params["kpts"].kpts,cartesian=False) # TODO: check cartesian false is correct.

parent_calculators = [f[0].directory for f in pw_calculator.linked_files.values() if f[0] is not None]
parent_calculators = [f[0].uid for f in pw_calculator.linked_files.values() if f[0] is not None]
if len(set(parent_calculators)) > 1:
raise ValueError("More than one parent calculator found.")
elif len(set(parent_calculators)) == 1:
Expand Down

0 comments on commit c906654

Please sign in to comment.