Skip to content

Commit

Permalink
Adding support for wann2kc and kc_ham.
Browse files Browse the repository at this point in the history
still kc_screen and alphas are not supported.
  • Loading branch information
mikibonacci committed Dec 6, 2024
1 parent cef1911 commit c6f0978
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 21 deletions.
53 changes: 37 additions & 16 deletions src/aiida_koopmans/calculations/kcw.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
# -*- coding: utf-8 -*-
"""`CalcJob` implementation for the kcw.x code of Quantum ESPRESSO."""
from pathlib import Path
import os

from aiida import orm
from aiida.plugins import DataFactory
from aiida_quantumespresso.calculations.namelists import NamelistsCalculation

SingleFileData = DataFactory('core.singlefile')

class KcwCalculation(NamelistsCalculation):
"""`CalcJob` implementation for the kcw.x code of Quantum ESPRESSO.
Expand Down Expand Up @@ -45,13 +44,13 @@ def define(cls, spec):
spec.input('kpoints', valid_type=orm.KpointsData, help='kpoint path if do_bands=True in the parameters', required=False)
#spec.input('wann_occ_hr', valid_type=SingleFileData, help='wann_occ_hr', required=False)
#spec.input('wann_emp_hr', valid_type=SingleFileData, help='wann_emp_hr', required=False)
spec.input('alpha_occ', valid_type=SingleFileData, help='alpha_occ', required=False)
spec.input('alpha_emp', valid_type=SingleFileData, help='alpha_emp', required=False)
spec.input('wann_u_mat', valid_type=SingleFileData, help='wann_occ_u', required=False)
spec.input('wann_emp_u_mat', valid_type=SingleFileData, help='wann_emp_u', required=False)
spec.input('wann_emp_u_dis_mat', valid_type=SingleFileData, help='wann_dis_u', required=False)
spec.input('wann_centres_xyz', valid_type=SingleFileData, help='wann_occ_centres', required=False)
spec.input('wann_emp_centres_xyz', valid_type=SingleFileData, help='wann_emp_centres', required=False)
spec.input('alpha_occ', valid_type=(orm.SinglefileData, orm.RemoteData), help='alpha_occ', required=False)
spec.input('alpha_emp', valid_type=(orm.SinglefileData, orm.RemoteData), help='alpha_emp', required=False)
spec.input('wann_u_mat', valid_type=(orm.SinglefileData, orm.RemoteData), help='wann_occ_u', required=False)
spec.input('wann_emp_u_mat', valid_type=(orm.SinglefileData, orm.RemoteData), help='wann_emp_u', required=False)
spec.input('wann_emp_u_dis_mat', valid_type=(orm.SinglefileData, orm.RemoteData), help='wann_dis_u', required=False)
spec.input('wann_centres_xyz', valid_type=(orm.SinglefileData, orm.RemoteData), help='wann_occ_centres', required=False)
spec.input('wann_emp_centres_xyz', valid_type=(orm.SinglefileData, orm.RemoteData), help='wann_emp_centres', required=False)
spec.input('settings', valid_type=orm.Dict, required=True, default=lambda: orm.Dict({
'CMDLINE': ["-in", cls._DEFAULT_INPUT_FILE],
}), help='Use an additional node for special settings',) #validator=validate_parameters,)
Expand All @@ -78,12 +77,27 @@ def define(cls, spec):

def prepare_for_submission(self, folder):
calcinfo = super().prepare_for_submission(folder)

for wann_file in ['wann_u_mat','wann_emp_u_mat','wann_emp_u_dis_mat','wann_centres_xyz','wann_emp_centres_xyz']:
if hasattr(self.inputs,wann_file):
wannier_singelfiledata = getattr(self.inputs, wann_file)
calcinfo.local_copy_list.append((wannier_singelfiledata.uuid, wannier_singelfiledata.filename, wann_file.replace("_mat",".mat").replace("_xyz",".xyz").replace("wann","aiida")))


for wann_input in ['wann_u_mat','wann_emp_u_mat','wann_emp_u_dis_mat','wann_centres_xyz','wann_emp_centres_xyz']:
wann_parent = getattr(self.inputs, wann_input, None)
if isinstance(wann_parent, orm.SinglefileData): # local copy to be send to the remote
calcinfo.local_copy_list.append((wann_parent.uuid, wann_parent.filename, wann_input.replace("_mat",".mat").replace("_xyz",".xyz").replace("wann","aiida")))
elif isinstance(wann_parent, orm.RemoteData):
# if remote, we symlink all the files
if wann_input == 'wann_u_mat':
for wann_file in ['wann_u_mat', 'wann_centres_xyz']:
calcinfo.remote_symlink_list.append(
create_symlink_tuple(parent_folder = wann_parent,
filename = wann_file.replace("_mat",".mat").replace("_xyz",".xyz").replace("wann","aiida"),
target = wann_file.replace("_mat",".mat").replace("_xyz",".xyz").replace("wann","aiida")))
elif wann_input == 'wann_emp_u_mat':
for wann_file in ['wann_emp_u_mat', 'wann_emp_centres_xyz', 'wann_emp_u_dis_mat']:
calcinfo.remote_symlink_list.append(
create_symlink_tuple(parent_folder = wann_parent,
filename = wann_file.replace("_mat",".mat").replace("_xyz",".xyz").replace("wann","aiida").replace("_emp",""),
target = wann_file.replace("_mat",".mat").replace("_xyz",".xyz").replace("wann","aiida")))

# TODO: fix the alphas copy
for alpha_file in ['alpha_occ','alpha_emp']:
if hasattr(self.inputs,alpha_file):
suffix = alpha_file.replace("alpha_occ","").replace("alpha_emp","_empty")
Expand All @@ -97,7 +111,14 @@ def prepare_for_submission(self, folder):
handle.write(kpoints_card)

return calcinfo


def create_symlink_tuple(parent_folder: orm.RemoteData, filename: str, target: str):
return (
parent_folder.computer.uuid,
os.path.join(parent_folder.get_remote_path(),
filename), target
)

def prepare_kpoints_card(kpoints=None):
# from the BasePwCpInputGenerator, I had to move it here as we cannot just inherit
from aiida.common import exceptions
Expand Down
7 changes: 5 additions & 2 deletions src/aiida_koopmans/engine/aiida.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def update_statuses(self) -> None:
elif workchain.is_finished or workchain.is_excepted or workchain.is_killed:
self._step_failed_message_by_uid(uid)
self.set_status_by_uid(uid, Status.FAILED)
raise ValueError(f"Workchain {workchain.pk} failed.")

return

Expand All @@ -155,15 +156,17 @@ def load_results(self, step: Step) -> None:
output = None
if step.ext_out == ".wout":
output = read_output_file(step, workchain.outputs.wannier90.retrieved)
elif step.ext_out in [".pwo",".kho"]:
if "remote_folder" in workchain.outputs.wannier90:
self.step_data['steps'][step.uid]['remote_folder'] = workchain.outputs.wannier90.remote_folder.pk
elif step.ext_out in [".pwo",".w2ko",".kso",".kho"]:
output = read_output_file(step, workchain.outputs.retrieved)
if hasattr(output.calc, 'kpts'):
step.kpts = output.calc.kpts
else:
output = read_output_file(step, workchain.outputs.retrieved)


if step.ext_out in [".pwo",".pro",".wout",".kso",".kho"]:
if step.ext_out in [".pwo",".pro",".wout",".w2ko",".kso",".kho"]:
step.calc = output.calc
step.results = output.calc.results
#if step.ext_out == ".pwo": step.generate_band_structure() #nelec=int(workchain.outputs.output_parameters.get_dict()['number_of_electrons']))
Expand Down
154 changes: 151 additions & 3 deletions src/aiida_koopmans/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,15 +288,163 @@ def get_projwfc_builder_from_ase(projwfc_calculator, step_data=None):

return builder, step_data

def get_kcw_builder_from_ase(kcw_calculator, step_data=None):

from aiida import load_profile, orm
load_profile()

aiida_inputs = step_data["configuration"]

# here we should find the parent folder and the wann files, merged or not (single block for emp or occ manifold).
parent_folder = None
wann_u_mat = None
wann_emp_u_mat = None
wann_emp_u_dis_mat = None
wann_centres_xyz = None
wann_emp_centres_xyz = None
for step_uid, val in step_data['steps'].items():
if "nscf" in step_uid:
nscf = orm.load_node(val["workchain"])
parent_folder = nscf.outputs.remote_folder
if "kcw_wannier" in step_uid:
w2kc = orm.load_node(val["workchain"])
parent_folder = w2kc.outputs.remote_folder

# SinglefileData merged files:
if "merge_occ_wannier_u" in step_uid:
wann_u_mat = orm.load_node(val['wannier90_u.mat'])
if "merge_occ_wannier_centers" in step_uid:
wann_centres_xyz = orm.load_node(val['wannier90_centres.xyz'])
if "merge_emp_wannier_u" in step_uid: # TODO: check if this is correct
wann_emp_u_mat = orm.load_node(val['wannier90_u.mat'])
if "merge_emp_wannier_centers" in step_uid:
wann_emp_centres_xyz = orm.load_node(val['wannier90_centres.xyz'])
if "merge_emp_wannier_u_dis" in step_uid:
wann_emp_u_dis_mat = orm.load_node(val['wannier90_u_dis.mat'])


# RemoteData folders: this is when only one block in occ or emp manifold.
# Instead of the SinglefileData (as searched above), we have only the RemoteData
# of the wannnier90 calc.
# TODO: explain this logic.
tmp_wann_emp_u_mat = None
for step_uid, val in step_data['steps'].items():

# the first hit is the single block of occ manifold,
# so we assign it and then we never hit again this block.
if not wann_u_mat and "03-wannier90" in step_uid:
wann_u_mat = orm.load_node(val["remote_folder"])

# we continue updating it up to the last hit.
# the last hit is the single block of emp manifold
if not wann_emp_u_mat and "03-wannier90" in step_uid:
tmp_wann_emp_u_mat = orm.load_node(val["remote_folder"])

if tmp_wann_emp_u_mat: wann_emp_u_mat = tmp_wann_emp_u_mat

# get the kcw calculator ext_out: we have three cases: w2ko, kso, kho
ext_out = kcw_calculator.ext_out

control_namelist = kcw_inputs_keys[ext_out]['control']
wannier_namelist = kcw_inputs_keys[ext_out]['wannier']

control_dict = {
k: v if k in control_namelist else None
for k, v in kcw_calculator.parameters.items()
if k not in ALL_BLOCKED_KEYWORDS
}

control_dict["calculation"] = "wann2kcw"
for k in list(control_dict):
if control_dict[k] is None:
control_dict.pop(k)

wannier_dict = {
k: v if k in wannier_namelist else None
for k, v in kcw_calculator.parameters.items()
# ? Using all here, as blocked Wannier90 keywords doesn't contain 'seedname', but kcw does
if k not in ALL_BLOCKED_KEYWORDS
}
for k in list(wannier_dict):
if wannier_dict[k] is None:
wannier_dict.pop(k)

screening_dict = {
k: v if k in kcs_keys['screen'] else None
for k, v in kcw_calculator.parameters.items()
if k not in ALL_BLOCKED_KEYWORDS
}
for k in list(screening_dict):
if screening_dict[k] is None:
screening_dict.pop(k)

ham_dict = {
k: v if k in kch_keys['ham'] else None
for k, v in kcw_calculator.parameters.items()
if k not in ALL_BLOCKED_KEYWORDS
}
for k in list(ham_dict):
if ham_dict[k] is None:
ham_dict.pop(k)

kcw_params = {
"CONTROL": control_dict,
"WANNIER": wannier_dict,
}
if ext_out == ".kso":
kcw_params["SCREEN"] = screening_dict
kcw_params["CONTROL"]["calculation"] = "screen"
elif ext_out == ".kho":
kcw_params["CONTROL"]["calculation"] = "ham"
kcw_params["HAM"] = ham_dict

# builder.
builder = KcwCalculation.get_builder()
builder.parameters = orm.Dict(kcw_params)
builder.code = orm.load_code(aiida_inputs["kcw_code"])

builder.metadata = aiida_inputs["metadata"]
if "metadata_kcw" in aiida_inputs:
builder.metadata = aiida_inputs["metadata_kcw"]

if ext_out == ".kho":
breakpoint()
# I provide kpoints as an array (output in the wannierized band structure), so I need to convert them.
kpoints = orm.KpointsData()
kpoints.set_kpoints(kcw_calculator._parameters.kpts.kpts, cartesian=False)
builder.kpoints = kpoints

builder.parent_folder = parent_folder

if control_dict.get(
"read_unitary_matrix", False
):
if wann_u_mat: builder.wann_u_mat = wann_u_mat
if wann_emp_u_mat: builder.wann_emp_u_mat = wann_emp_u_mat
if wann_emp_u_dis_mat: builder.wann_emp_u_dis_mat = wann_emp_u_dis_mat
if wann_centres_xyz: builder.wann_centres_xyz = wann_centres_xyz
if wann_emp_centres_xyz: builder.wann_emp_centres_xyz = wann_centres

if hasattr(kcw_calculator, "alphas"): # TODO: add support for this.
builder.alpha_occ = kcw_calculator.alphas_files["alpha"]
builder.alpha_emp = kcw_calculator.alphas_files["alpha_empty"]

return builder, step_data

## Here we have the mapping for the calculators initialization. used in the `aiida_calculate_trigger`.
mapping_calculators = {
".pwo" : get_PwBaseWorkChain_from_ase,
".wout": get_Wannier90BandsWorkChain_builder_from_ase,
".pro": get_projwfc_builder_from_ase,
#".w2ko": from_wann2kc_to_KcwCalculation,
#".kso": from_kcwscreen_to_KcwCalculation,
#".kho": from_kcwham_to_KcwCalculation,
".w2ko": get_kcw_builder_from_ase,
".kso": get_kcw_builder_from_ase,
".kho": get_kcw_builder_from_ase,
}

kcw_inputs_keys = {
".w2ko": w2kcw_keys,
".kso": kcs_keys,
".kho": kch_keys,
}

# read the output file, mimicking the read_results method of ase-koopmans: https://github.com/elinscott/ase_koopmans/blob/master/ase/calculators/espresso/_espresso.py
Expand Down

0 comments on commit c6f0978

Please sign in to comment.