
Commit

Support for the new engine class in koopmans.
mikibonacci committed Dec 3, 2024
1 parent cbfacab commit 07edbf9
Showing 6 changed files with 492 additions and 85 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
@@ -26,7 +26,6 @@ dependencies = [
"voluptuous",
"aiida-quantumespresso~=4.6.0",
"aiida-wannier90-workflows @ git+https://github.com/aiidateam/aiida-wannier90-workflows",
"koopmans @ git+https://github.com/mikibonacci/koopmans@feature/ase-to-pw"
]

[project.urls]
4 changes: 2 additions & 2 deletions src/aiida_koopmans/data/utils.py
@@ -14,7 +14,7 @@ def generate_singlefiledata(filename, flines):

return file

def produce_wannier90_files(calc_w90,merge_directory_name,method="dfpt"):
def produce_wannier90_files(calc_w90,label,method="dfpt"):
"""producing the wannier90 files in the case of just one occ and/or one emp blocks.
Args:
@@ -34,7 +34,7 @@ def produce_wannier90_files(calc_w90,merge_directory_name,method="dfpt"):

standard_dictionary = {'hr_dat':hr_singlefile, "u_mat": u_singlefile, "centres_xyz": centres_singlefile}

if method == 'dfpt' and merge_directory_name == "emp":
if method == 'dfpt' and label == "emp":
u_dis_file = calc_w90.wchain.outputs.wannier90.retrieved.get_object_content('aiida' + '_u_dis.mat')
u_dis_singlefile = generate_singlefiledata('aiida' + '_u_dis.mat', u_dis_file)
standard_dictionary["u_dis_mat"] = u_dis_singlefile
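For reference, a minimal usage sketch of the renamed signature (illustrative only; it assumes `calc_w90` is an ASE Wannier90 calculator carrying a finished AiiDA `wchain`, as the docstring above requires):

# Build the SinglefileData set for the empty-states block.
w90_files = produce_wannier90_files(calc_w90, label="emp", method="dfpt")
# Expected keys, following the function body shown above:
# 'hr_dat', 'u_mat', 'centres_xyz', plus 'u_dis_mat' when label == "emp" and method == "dfpt".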
164 changes: 164 additions & 0 deletions src/aiida_koopmans/engine/aiida.py
@@ -0,0 +1,164 @@
from koopmans.engines.engine import Engine
from koopmans.step import Step
from koopmans.calculators import Calc
from koopmans.pseudopotentials import read_pseudo_file

from aiida.engine import run_get_node, submit

from aiida_koopmans.utils import *

from aiida_pseudo.data.pseudo import UpfData

import pathlib
import tempfile
import time

import dill as pickle

from aiida import orm, load_profile
load_profile()

class AiiDAEngine(Engine):

"""
Step data is a dictionary containing the following information:
step_data = {calc.directory: {'workchain': workchain, 'remote_folder': remote_folder}}
and any other info we need for AiiDA.
"""

def __init__(self, *args, **kwargs):
self.blocking = kwargs.pop('blocking', True)
self.step_data = { # TODO: change to a better name
'configuration': kwargs.pop('configuration', None),
'steps': {}
}

# here we add the logic to populate configuration by default
# 1. we look for codes stored in AiiDA at localhost, e.g. pw-version@localhost,
# 2. we look for codes in the PATH,
# 3. if we don't find the code in AiiDA db but in the PATH, we store it in AiiDA db.
        # 4. if we don't find the code either in the AiiDA db or in the PATH and no configuration is provided, we raise an error.
if self.step_data['configuration'] is None:
raise NotImplementedError("Configuration not provided")

# 5. if no resource info in configuration, we try to look at PARA_PREFIX env var.


super().__init__(*args, **kwargs)


def _run_steps(self, steps: tuple[Step, ...]) -> None:
try:
with open('step_data.pkl', 'rb') as f:
                self.step_data = pickle.load(f)  # note: this overwrites step_data['configuration'], i.e. changes to codes or resources are not picked up when the checkpoint file already exists.
        except FileNotFoundError:  # no checkpoint yet; start from scratch
pass

self.from_scratch = False
for step in steps:
# self._step_running_message(step)
if step.directory in self.step_data['steps']:
continue
elif step.prefix in ['wannier90_preproc', 'pw2wannier90']:
print(f'skipping {step.prefix} step')
continue
else:
self.from_scratch = True

#step.run()
self.step_data['steps'][step.directory] = {}
            builder = get_builder_from_ase(calculator=step, step_data=self.step_data)  # ASE-to-AiiDA conversion; TODO: raise a clear error message if the conversion fails
running = submit(builder)
# running = aiidawrapperwchain.submit(builder) # in the non-blocking case.
self.step_data['steps'][step.directory] = {'workchain': running.pk, } #'remote_folder': running.outputs.remote_folder}

#if self.from_scratch:
with open('step_data.pkl', 'wb') as f:
pickle.dump(self.step_data, f)

if not self.blocking and self.from_scratch:
raise CalculationSubmitted("Calculation submitted to AiiDA, non blocking")
elif self.blocking:
for step in self.step_data['steps'].values():
while not orm.load_node(step['workchain']).is_finished:
time.sleep(5)

for step in steps:
# convert from AiiDA to ASE results and populate ASE calculator
            # TODO: move this into a separate function
if step.prefix in ['wannier90_preproc', 'pw2wannier90']:
continue
workchain = orm.load_node(self.step_data['steps'][step.directory]['workchain'])
if "remote_folder" in workchain.outputs:
self.step_data['steps'][step.directory]['remote_folder'] = workchain.outputs.remote_folder.pk
output = None
if step.ext_out == ".wout":
output = read_output_file(step, workchain.outputs.wannier90.retrieved)
elif step.ext_out in [".pwo",".kho"]:
output = read_output_file(step, workchain.outputs.retrieved)
if hasattr(output.calc, 'kpts'):
step.kpts = output.calc.kpts
else:
output = read_output_file(step, workchain.outputs.retrieved)
if step.ext_out in [".pwo",".wout",".kso",".kho"]:
step.calc = output.calc
step.results = output.calc.results
step.generate_band_structure(nelec=int(workchain.outputs.output_parameters.get_dict()['number_of_electrons']))

self._step_completed_message(step)

# If we reached here, all future steps should be performed from scratch
self.from_scratch = True

        # dump again to keep the stored step information up to date
with open('step_data.pkl', 'wb') as f:
pickle.dump(self.step_data, f)

return

def load_old_calculator(self, calc: Calc):
raise NotImplementedError # load_old_calculator(calc)

def get_pseudo_data(self, workflow):
pseudo_data = {}
symbols_list = []
for symbol in workflow.pseudopotentials.keys():
symbols_list.append(symbol)

qb = orm.QueryBuilder()
qb.append(orm.Group, filters={'label': {'==': 'pseudo_group'}}, tag='pseudo_group')
qb.append(UpfData, filters={'attributes.element': {'in': symbols_list}}, with_group='pseudo_group')

for pseudo in qb.all():
with tempfile.TemporaryDirectory() as dirpath:
                element = pseudo[0].element
                temp_file = pathlib.Path(dirpath) / (element + '.upf')
                with pseudo[0].open(element + '.upf', 'rb') as handle:
                    temp_file.write_bytes(handle.read())
                pseudo_data[element] = read_pseudo_file(temp_file)

return pseudo_data


def load_old_calculator(calc):
# This is a separate function so that it can be imported by other engines
loaded_calc = calc.__class__.fromfile(calc.directory / calc.prefix)

if loaded_calc.is_complete():
# If it is complete, load the results
calc.results = loaded_calc.results

# Check the convergence of the calculation
calc.check_convergence()

# Load k-points if relevant
if hasattr(loaded_calc, 'kpts'):
calc.kpts = loaded_calc.kpts

if isinstance(calc, ReturnsBandStructure):
calc.generate_band_structure()

if isinstance(calc, ProjwfcCalculator):
calc.generate_dos()

if isinstance(calc, PhCalculator):
calc.read_dynG()

return loaded_calc
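A hypothetical driver script for the new engine is sketched below; the configuration payload and its keys are placeholders, since this commit adds only the engine class and does not document the expected schema:

from aiida_koopmans.engine.aiida import AiiDAEngine

# Placeholder configuration (assumed keys): codes and scheduler resources for the
# AiiDA builders. Passing configuration=None raises NotImplementedError above.
configuration = {"pw_code": "pw-7.2@localhost", "metadata": {}}

engine = AiiDAEngine(blocking=True, configuration=configuration)
# The engine is then handed to a koopmans workflow, whose steps end up in
# engine._run_steps(); submitted workchains are checkpointed to step_data.pkl,
# so an interrupted (or non-blocking) run can be resumed by re-running the script.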
74 changes: 50 additions & 24 deletions src/aiida_koopmans/helpers.py
@@ -23,7 +23,7 @@
from ase.io.espresso import kch_keys, kcp_keys, kcs_keys, pw_keys, w2kcw_keys

from aiida_koopmans.calculations.kcw import KcwCalculation
from aiida_koopmans.data.utils import generate_singlefiledata, generate_alpha_singlefiledata
from aiida_koopmans.data.utils import generate_singlefiledata, generate_alpha_singlefiledata, produce_wannier90_files

"""
ASE calculator MUST have `wchain` attribute (the related AiiDA WorkChain) to be able to use these functions!
@@ -272,15 +272,16 @@ def from_wann2kc_to_KcwCalculation(wann2kc_calculator):
builder.metadata = wann2kc_calculator.parameters.mode["metadata_kcw"]
builder.parent_folder = wann2kc_calculator.parent_folder

if hasattr(wann2kc_calculator, "w90_files"):
builder.wann_emp_u_mat = wann2kc_calculator.w90_files["emp"]["u_mat"]
builder.wann_emp_u_dis_mat = wann2kc_calculator.w90_files["emp"][
if hasattr(wann2kc_calculator, "wannier90_files"):
builder.wann_emp_u_mat = wann2kc_calculator.wannier90_files["emp"]["u_mat"]
builder.wann_u_mat = wann2kc_calculator.wannier90_files["occ"]["u_mat"]
builder.wann_emp_u_dis_mat = wann2kc_calculator.wannier90_files["emp"][
"u_dis_mat"
]
builder.wann_centres_xyz = wann2kc_calculator.w90_files["occ"][
builder.wann_centres_xyz = wann2kc_calculator.wannier90_files["occ"][
"centres_xyz"
]
builder.wann_emp_centres_xyz = wann2kc_calculator.w90_files["emp"][
builder.wann_emp_centres_xyz = wann2kc_calculator.wannier90_files["emp"][
"centres_xyz"
]

@@ -354,14 +355,14 @@ def from_kcwham_to_KcwCalculation(kcw_calculator):
builder.metadata = kcw_calculator.parameters.mode["metadata_kcw"]
builder.parent_folder = kcw_calculator.parent_folder

if hasattr(kcw_calculator, "w90_files") and control_dict.get(
if hasattr(kcw_calculator, "wannier90_files") and control_dict.get(
"read_unitary_matrix", False
):
builder.wann_u_mat = kcw_calculator.w90_files["occ"]["u_mat"]
builder.wann_emp_u_mat = kcw_calculator.w90_files["emp"]["u_mat"]
builder.wann_emp_u_dis_mat = kcw_calculator.w90_files["emp"]["u_dis_mat"]
builder.wann_centres_xyz = kcw_calculator.w90_files["occ"]["centres_xyz"]
builder.wann_emp_centres_xyz = kcw_calculator.w90_files["emp"][
builder.wann_u_mat = kcw_calculator.wannier90_files["occ"]["u_mat"]
builder.wann_emp_u_mat = kcw_calculator.wannier90_files["emp"]["u_mat"]
builder.wann_emp_u_dis_mat = kcw_calculator.wannier90_files["emp"]["u_dis_mat"]
builder.wann_centres_xyz = kcw_calculator.wannier90_files["occ"]["centres_xyz"]
builder.wann_emp_centres_xyz = kcw_calculator.wannier90_files["emp"][
"centres_xyz"
]

@@ -438,14 +439,14 @@ def from_kcwscreen_to_KcwCalculation(kcw_calculator):
builder.metadata = kcw_calculator.parameters.mode["metadata_kcw"]
builder.parent_folder = kcw_calculator.parent_folder

if hasattr(kcw_calculator, "w90_files") and control_dict.get(
if hasattr(kcw_calculator, "wannier90_files") and control_dict.get(
"read_unitary_matrix", False
):
builder.wann_u_mat = kcw_calculator.w90_files["occ"]["u_mat"]
builder.wann_emp_u_mat = kcw_calculator.w90_files["emp"]["u_mat"]
builder.wann_emp_u_dis_mat = kcw_calculator.w90_files["emp"]["u_dis_mat"]
builder.wann_centres_xyz = kcw_calculator.w90_files["occ"]["centres_xyz"]
builder.wann_emp_centres_xyz = kcw_calculator.w90_files["emp"][
builder.wann_u_mat = kcw_calculator.wannier90_files["occ"]["u_mat"]
builder.wann_emp_u_mat = kcw_calculator.wannier90_files["emp"]["u_mat"]
builder.wann_emp_u_dis_mat = kcw_calculator.wannier90_files["emp"]["u_dis_mat"]
builder.wann_centres_xyz = kcw_calculator.wannier90_files["occ"]["centres_xyz"]
builder.wann_emp_centres_xyz = kcw_calculator.wannier90_files["emp"][
"centres_xyz"
]
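The three hunks above (wann2kc, kc_ham and kc_screen) all read the same calculator attribute; a sketch of the layout they assume, with aiida.orm.SinglefileData nodes as values (keys taken from the code above and from produce_wannier90_files):

wannier90_files = {
    "occ": {"hr_dat": ..., "u_mat": ..., "centres_xyz": ...},
    "emp": {"hr_dat": ..., "u_mat": ..., "u_dis_mat": ..., "centres_xyz": ...},
}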

@@ -661,12 +662,21 @@ def wrapper_aiida_trigger(self):
if self.parameters.mode == "ase":
return _fetch_linked_files(self)
else: # if pseudo linking, src_calc = None
self.wannier90_files = {}
for dest_filename, (src_calc, src_filename, symlink, recursive_symlink, overwrite) in self.linked_files.items():
                # check that we have only one src_calc!
if hasattr(src_calc,"wchain"):
self.parent_folder = src_calc.wchain.outputs.remote_folder
elif hasattr(src_calc,"dst_file"):
                    pass # link w90_files.
# semi-complex logic to have the w90 files for the wann2kc:
if "wannier90" in dest_filename:
if "emp" in dest_filename:
self.wannier90_files["emp"] = produce_wannier90_files(src_calc,"emp")
else:
self.wannier90_files["occ"] = produce_wannier90_files(src_calc,"occ")
else:
self.parent_folder = src_calc.wchain.outputs.remote_folder
elif hasattr(src_calc,"wannier90_files"):
                    pass # link wannier90_files in case we merged multiple Wannierizations.
if self.wannier90_files == {}: delattr(self, "wannier90_files")
return wrapper_aiida_trigger

# get files to manipulate further.
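The destination-filename dispatch added above can be restated as a small sketch (the example filenames are hypothetical; only the substring checks come from the code):

def classify(dest_filename):
    # mirrors the branches in the wrapper above
    if "wannier90" in dest_filename:
        return "emp" if "emp" in dest_filename else "occ"  # -> self.wannier90_files[block]
    return "parent_folder"                                  # -> src_calc.wchain.outputs.remote_folder

assert classify("wann_emp/wannier90_emp.chk") == "emp"
assert classify("wannier90.chk") == "occ"
assert classify("pseudo/Si.upf") == "parent_folder"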
@@ -730,22 +740,38 @@ def wrapper_aiida_trigger(self):
return wrapper_aiida_trigger

def aiida_merge_wannier_files_trigger(merge_wannier_files):
# needed to populate the self.w90_files dictionary in the wannierize workflow.
# needed to populate the self.wannier90_files dictionary in the wannierize workflow.
@functools.wraps(merge_wannier_files)
def wrapper_aiida_trigger(self, block, merge_directory, prefix):
merge_wannier_files(self,block, merge_directory, prefix)
if self.parameters.mode == "ase":
return
else:
self.w90_files[merge_directory] = {
self.wannier90_files[merge_directory] = {
'hr_dat': self.merge_hr_proc.singlefiledata,
}
del self.merge_hr_proc.singlefiledata
if self.parameters.method == 'dfpt':
self.w90_files[merge_directory].update({
self.wannier90_files[merge_directory].update({
"u_mat": self.merge_u_proc.singlefiledata,
"centres_xyz": self.merge_centers_proc.singlefiledata,
})
del self.merge_u_proc.singlefiledata
del self.merge_centers_proc.singlefiledata
return wrapper_aiida_trigger

def aiida_trigger_run_calculator(run_calculator):
# This wraps the run_calculator method.
@functools.wraps(run_calculator)
def wrapper_aiida_trigger(self, calc):
if calc.prefix in ["wannier90_preproc","pw2wannier90"]:
return
elif calc.prefix in ["wannier90"]:
from koopmans import calculators
calc_nscf = [c for c in self.calculations if isinstance(
c, calculators.PWCalculator) and c.parameters.calculation == 'nscf'][-1]
self.link(calc_nscf, calc_nscf.parameters.outdir, calc, "")
return run_calculator(self,calc)
else:
return run_calculator(self,calc)
return wrapper_aiida_trigger
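All of these aiida_*_trigger helpers follow the same wrap-and-delegate pattern. A self-contained sketch of that pattern (illustrative, not part of the commit):

import functools

def skip_prefixes(run_calculator, skipped=("wannier90_preproc", "pw2wannier90")):
    """Wrap a run_calculator-style method and short-circuit for selected prefixes."""
    @functools.wraps(run_calculator)
    def wrapper(self, calc):
        if getattr(calc, "prefix", None) in skipped:
            return None  # nothing to run for these steps when AiiDA drives the calculations
        return run_calculator(self, calc)
    return wrapper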
58 changes: 0 additions & 58 deletions src/aiida_koopmans/parsers.py

This file was deleted.

