Skip to content

Commit

Permalink
Merge pull request #10 from GeigerJ2/projwfc-bands-workflow
Browse files Browse the repository at this point in the history
First working version of projwfc calc via temporarily setting calculator dir.
  • Loading branch information
mikibonacci authored Dec 5, 2024
2 parents bf0a97e + fd2230a commit b301dc6
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 44 deletions.
77 changes: 46 additions & 31 deletions src/aiida_koopmans/engine/aiida.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,40 +18,38 @@
load_profile()

class AiiDAEngine(Engine):

"""
Step data is a dictionary containing the following information:
step_data = {calc.directory: {'workchain': workchain, 'remote_folder': remote_folder}}
and any other info we need for AiiDA.
and any other info we need for AiiDA.
"""
def __init__(self, *args, **kwargs):
self.blocking = kwargs.pop('blocking', True)
self.step_data = { # TODO: change to a better name
'configuration': kwargs.pop('configuration', None),
'steps': {}
}
self.skip_message = False

# here we add the logic to populate configuration by default
# 1. we look for codes stored in AiiDA at localhost, e.g. pw-version@localhost,
# 2. we look for codes in the PATH,
# 2. we look for codes in the PATH,
# 3. if we don't find the code in AiiDA db but in the PATH, we store it in AiiDA db.
# 4. if we don't find the code in AiiDA db and in the PATH and not configuration is provided, we raise an error.
if self.step_data['configuration'] is None:
raise NotImplementedError("Configuration not provided")

# 5. if no resource info in configuration, we try to look at PARA_PREFIX env var.


super().__init__(*args, **kwargs)

def run(self, step: Step):

self.get_status(step)
if step.prefix in ['wannier90_preproc', 'pw2wannier90']:
self.set_status(step, Status.COMPLETED)
return

self.step_data['steps'][step.uid] = {} # maybe not needed
builder, self.step_data = get_builder_from_ase(calculator=step, step_data=self.step_data) # ASE to AiiDA conversion. put some error message if the conversion fails
running = submit(builder)
Expand All @@ -60,25 +58,25 @@ def run(self, step: Step):

# The below will be passed to the context, so we will need to store also the instance of the submitted workchain, if in KoopmansWorkChain.
self.step_data['steps'][step.uid] = {'workchain': running.pk, } #'remote_folder': running.outputs.remote_folder}

self.set_status(step, Status.RUNNING)
return

return

def load_step_data(self):
try:
with open('step_data.pkl', 'rb') as f:
# this will overwrite the step_data[configuration],
# i.e. if we change codes or res we will not see it if
# this will overwrite the step_data[configuration],
# i.e. if we change codes or res we will not see it if
# the file already exists.
self.step_data = pickle.load(f)
self.step_data = pickle.load(f)
except FileNotFoundError:
pass

def dump_step_data(self):
with open('step_data.pkl', 'wb') as f:
pickle.dump(self.step_data, f)

def get_status(self, step: Step) -> Status:
status = self.get_status_by_uid(step.uid)
#print(f"Getting status for step {step.uid}: {status}")
Expand All @@ -90,36 +88,35 @@ def get_status_by_uid(self, uid: str) -> Status:
if uid not in self.step_data['steps']:
self.step_data['steps'][uid] = {'status': Status.NOT_STARTED}
return self.step_data['steps'][uid]['status']

def set_status(self, step: Step, status: Status):
self.set_status_by_uid(step.uid, status)
#print(f"Step {step.uid} is {status}")

def set_status_by_uid(self, uid: str, status: Status):
self.step_data['steps'][uid]['status'] = status
self.dump_step_data()

def update_statuses(self) -> None:

time.sleep(1)
for uid in self.step_data['steps']:

if not self.get_status_by_uid(uid) == Status.RUNNING:
continue

workchain = orm.load_node(self.step_data['steps'][uid]['workchain'])
if workchain.is_finished_ok:
self._step_completed_message_by_uid(uid)
self.set_status_by_uid(uid, Status.COMPLETED)

elif workchain.is_finished or workchain.is_excepted or workchain.is_killed:
self._step_failed_message_by_uid(uid)
self.set_status_by_uid(uid, Status.FAILED)

return

def load_results(self, step: Step) -> None:

self.load_step_data()

if step.prefix in ['wannier90_preproc', 'pw2wannier90']:
Expand All @@ -141,10 +138,27 @@ def load_results(self, step: Step) -> None:
step.calc = output.calc
step.results = output.calc.results
if step.ext_out == ".pwo": step.generate_band_structure() #nelec=int(workchain.outputs.output_parameters.get_dict()['number_of_electrons']))


self._step_completed_message(step)

if step.ext_out in [".pro"]:

pdos_dir = dump_pdos_outputs(step, workchain.outputs.retrieved)
prev_dir = step.directory
step.directory = pdos_dir

try:
step.generate_dos()
except ValueError:
# ValueError: Must provide energies to create a GridDOSCollection without any DOS data.
pass
finally:
from aiida_koopmans.utils import delete_directory
delete_directory(pdos_dir.parent)
step.directory = prev_dir

self.dump_step_data()



def load_old_calculator(self, calc: Calc):
raise NotImplementedError # load_old_calculator(calc)

Expand All @@ -160,6 +174,7 @@ def get_pseudopotential(self, library: str, element: str):
temp_file = pathlib.Path(dirpath) / (pseudo[0].base.attributes.all['element'] + '.upf')
with pseudo[0].open(pseudo[0].base.attributes.all['element'] + '.upf', 'rb') as handle:
temp_file.write_bytes(handle.read())

pseudo_data = read_pseudo_file(temp_file)

if not pseudo_data:
Expand Down
108 changes: 95 additions & 13 deletions src/aiida_koopmans/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import shutil
import pathlib
import tempfile
Expand All @@ -9,6 +8,7 @@
from aiida.common.exceptions import NotExistent
from aiida.orm import Code, Computer
from aiida_quantumespresso.calculations.pw import PwCalculation
from aiida_quantumespresso.calculations.projwfc import ProjwfcCalculation
from aiida_wannier90.calculations.wannier90 import Wannier90Calculation

from ase import Atoms
Expand All @@ -22,8 +22,9 @@
LOCALHOST_NAME = "localhost-test"
KCW_BLOCKED_KEYWORDS = [t[1] for t in KcwCalculation._blocked_keywords]
PW_BLOCKED_KEYWORDS = [t[1] for t in PwCalculation._blocked_keywords]
PROJWFC_BLOCKED_KEYWORDS = [t[1] for t in ProjwfcCalculation._blocked_keywords]
WANNIER90_BLOCKED_KEYWORDS = [t[1] for t in Wannier90Calculation._BLOCKED_PARAMETER_KEYS]
ALL_BLOCKED_KEYWORDS = KCW_BLOCKED_KEYWORDS + PW_BLOCKED_KEYWORDS + WANNIER90_BLOCKED_KEYWORDS + [f'celldm({i})' for i in range (1,7)]
ALL_BLOCKED_KEYWORDS = KCW_BLOCKED_KEYWORDS + PW_BLOCKED_KEYWORDS + WANNIER90_BLOCKED_KEYWORDS + PROJWFC_BLOCKED_KEYWORDS + [f'celldm({i})' for i in range (1,7)]

def get_builder_from_ase(calculator, step_data=None):
return mapping_calculators[calculator.ext_out](calculator, step_data)
Expand All @@ -38,7 +39,7 @@ def get_PwBaseWorkChain_from_ase(pw_calculator, step_data=None):

aiida_inputs = step_data['configuration']
calc_params = pw_calculator._parameters

structure = None
parent_folder = None
for step, val in step_data['steps'].items():
Expand Down Expand Up @@ -83,7 +84,7 @@ def get_PwBaseWorkChain_from_ase(pw_calculator, step_data=None):
builder.pw.metadata = aiida_inputs["metadata"]

builder.kpoints = orm.KpointsData()

if pw_overrides["CONTROL"]["calculation"] in ["scf", "nscf"]:
builder.kpoints.set_kpoints_mesh(calc_params["kpts"])
elif pw_overrides["CONTROL"]["calculation"] == "bands":
Expand Down Expand Up @@ -123,8 +124,8 @@ def get_Wannier90BandsWorkChain_builder_from_ase(w90_calculator, step_data=None)
nscf = orm.load_node(val["workchain"])
if not nscf:
raise ValueError("No nscf step found.")


aiida_inputs = step_data['configuration']

codes = {
Expand Down Expand Up @@ -160,11 +161,11 @@ def get_Wannier90BandsWorkChain_builder_from_ase(w90_calculator, step_data=None)
t = np.where(k_linear==coords)[0]
k_labels.append([t[0],label])
k_coords.append(special_k[label].tolist())

kpoints_path = orm.KpointsData()
kpoints_path.set_kpoints(k_path,labels=k_labels,cartesian=False)
builder.kpoint_path = kpoints_path


# Start parameters and projections setting using the Wannier90Calculator data.
params = builder.wannier90.wannier90.parameters.get_dict()
Expand Down Expand Up @@ -238,10 +239,61 @@ def get_Wannier90BandsWorkChain_builder_from_ase(w90_calculator, step_data=None)

return builder, step_data


def get_projwfc_builder_from_ase(projwfc_calculator, step_data=None):
from aiida import load_profile, orm
from aiida_quantumespresso.calculations.projwfc import ProjwfcCalculation

load_profile()

"""
Convert a `ProjwfcCalculator` into an AiiDA `ProjwfcCalculation
"""

aiida_inputs = step_data["configuration"]
calc_params = projwfc_calculator._parameters

# TODO: This is not needed, if we can just pass `orm.Dict(calc_params)` to the builder
from koopmans.settings import ProjwfcSettingsDict

projwfc_parameters = {}
projwfcsettingsdict = ProjwfcSettingsDict()
projwfc_keys = (
projwfcsettingsdict.valid
+ list(projwfcsettingsdict.defaults.keys())
+ projwfcsettingsdict.are_paths
)
for k in projwfc_keys:
if k in calc_params.keys() and k not in ALL_BLOCKED_KEYWORDS:
projwfc_parameters[k] = calc_params[k]

projwfc_parameters['filpdos'] = 'aiida'

builder = ProjwfcCalculation.get_builder()
builder.code = orm.load_code(aiida_inputs["projwfc_code"])
builder.parameters = orm.Dict({"PROJWFC": projwfc_parameters})
builder.metadata = aiida_inputs["metadata"]

parent_calculators = [
f[0].uid for f in projwfc_calculator.linked_files.values() if f[0] is not None
]

if len(set(parent_calculators)) > 1:
raise ValueError("More than one parent calculator found.")
elif len(set(parent_calculators)) == 1:
if "remote_folder" in step_data["steps"][parent_calculators[0]]:
builder.parent_folder = orm.load_node(
step_data["steps"][parent_calculators[0]]["remote_folder"]
)

return builder


## Here we have the mapping for the calculators initialization. used in the `aiida_calculate_trigger`.
mapping_calculators = {
".pwo" : get_PwBaseWorkChain_from_ase,
".wout": get_Wannier90BandsWorkChain_builder_from_ase,
".pro": get_projwfc_builder_from_ase,
#".w2ko": from_wann2kc_to_KcwCalculation,
#".kso": from_kcwscreen_to_KcwCalculation,
#".kho": from_kcwham_to_KcwCalculation,
Expand All @@ -250,13 +302,13 @@ def get_Wannier90BandsWorkChain_builder_from_ase(w90_calculator, step_data=None)
# read the output file, mimicking the read_results method of ase-koopmans: https://github.com/elinscott/ase_koopmans/blob/master/ase/calculators/espresso/_espresso.py
def read_output_file(calculator, retrieved, inner_remote_folder=None):
"""
Read the output file of a calculator using ASE io.read() method but parsing the AiiDA outputs.
Read the output file of a calculator using ASE io.read() method but parsing the AiiDA outputs.
NB: calculator (ASE) should contain the related AiiDA workchain as attribute.
"""
#if inner_remote_folder:
# if inner_remote_folder:
# retrieved = inner_remote_folder
#else:
#retrieved = workchain.outputs.retrieved
# else:
# retrieved = workchain.outputs.retrieved
with tempfile.TemporaryDirectory() as dirpath:
# Open the output file from the AiiDA storage and copy content to the temporary file
for filename in retrieved.base.repository.list_object_names():
Expand All @@ -267,4 +319,34 @@ def read_output_file(calculator, retrieved, inner_remote_folder=None):
with retrieved.open(filename, 'rb') as handle:
temp_file.write_bytes(handle.read())
output = io.read(temp_file)
return output
return output


def dump_pdos_outputs(calculator, retrieved):
"""
Dump the `pdos` output files of a projwfc.x calculation run via AiiDA to a temporary directory which is returned.
"""

output_dir = calculator.directory / pathlib.Path(tempfile.mkdtemp()).parts[-1]
output_dir.mkdir(exist_ok=True, parents=True)

for filename in retrieved.base.repository.list_object_names():
if ".pdos" in filename:
# Create the file with the desired name
output_file = pathlib.Path(output_dir) / (
f"{calculator.parameters.filpdos}." + filename.replace("aiida.", "")
)
with retrieved.open(filename, "rb") as handle:
output_file.write_bytes(handle.read())

return output_dir


def delete_directory(dir_path):
dir_path = pathlib.Path(dir_path)
for child in dir_path.iterdir():
if child.is_dir():
delete_directory(child)
else:
child.unlink()
dir_path.rmdir()

0 comments on commit b301dc6

Please sign in to comment.