From 07edbf90ab136f5d2ed60a6266eab28295c5e49b Mon Sep 17 00:00:00 2001 From: mikibonacci Date: Tue, 3 Dec 2024 08:10:30 +0000 Subject: [PATCH] Support for the new engine class in koopmans. --- pyproject.toml | 1 - src/aiida_koopmans/data/utils.py | 4 +- src/aiida_koopmans/engine/aiida.py | 164 +++++++++++++++++ src/aiida_koopmans/helpers.py | 74 +++++--- src/aiida_koopmans/parsers.py | 58 ------ src/aiida_koopmans/utils.py | 276 +++++++++++++++++++++++++++++ 6 files changed, 492 insertions(+), 85 deletions(-) create mode 100644 src/aiida_koopmans/engine/aiida.py delete mode 100644 src/aiida_koopmans/parsers.py create mode 100644 src/aiida_koopmans/utils.py diff --git a/pyproject.toml b/pyproject.toml index c6d5847..d177c6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,6 @@ dependencies = [ "voluptuous", "aiida-quantumespresso~=4.6.0", "aiida-wannier90-workflows @ git+https://github.com/aiidateam/aiida-wannier90-workflows", - "koopmans @ git+https://github.com/mikibonacci/koopmans@feature/ase-to-pw" ] [project.urls] diff --git a/src/aiida_koopmans/data/utils.py b/src/aiida_koopmans/data/utils.py index 75d6431..4d822df 100644 --- a/src/aiida_koopmans/data/utils.py +++ b/src/aiida_koopmans/data/utils.py @@ -14,7 +14,7 @@ def generate_singlefiledata(filename, flines): return file -def produce_wannier90_files(calc_w90,merge_directory_name,method="dfpt"): +def produce_wannier90_files(calc_w90,label,method="dfpt"): """producing the wannier90 files in the case of just one occ and/or one emp blocks. 
# Reconstructed from the mangled patch hunk for the new file
# src/aiida_koopmans/engine/aiida.py (diff "+" markers and collapsed
# newlines removed, formatting restored).

import pathlib
import tempfile
import time

import dill as pickle

from koopmans.engines.engine import Engine
from koopmans.step import Step
from koopmans.calculators import Calc
from koopmans.pseudopotentials import read_pseudo_file

from aiida import orm, load_profile
from aiida.engine import run_get_node, submit

from aiida_pseudo.data.pseudo import UpfData

from aiida_koopmans.utils import *

load_profile()


class CalculationSubmitted(Exception):
    """Raised in non-blocking mode once all steps have been submitted to AiiDA."""
    # NOTE(review): the original patch raised this name without defining it
    # anywhere (NameError at runtime); defined here so non-blocking mode works.


class AiiDAEngine(Engine):
    """Koopmans engine that delegates each workflow step to AiiDA.

    ``step_data`` is a dictionary of the form::

        {'configuration': <codes/resources configuration>,
         'steps': {step.directory: {'workchain': pk, 'remote_folder': pk}}}

    plus any other info needed for AiiDA.
    """

    def __init__(self, *args, **kwargs):
        # blocking=True: wait for every submitted workchain to finish before
        # parsing results; blocking=False: submit and raise CalculationSubmitted.
        self.blocking = kwargs.pop('blocking', True)
        self.step_data = {  # TODO: change to a better name
            'configuration': kwargs.pop('configuration', None),
            'steps': {},
        }

        # TODO: auto-populate the configuration:
        # 1. look for codes stored in AiiDA at localhost (e.g. pw-version@localhost);
        # 2. look for codes in the PATH;
        # 3. if a code is in the PATH but not in the AiiDA db, store it;
        # 4. if it is in neither and no configuration is provided, raise an error;
        # 5. if no resource info is given, fall back to the PARA_PREFIX env var.
        if self.step_data['configuration'] is None:
            raise NotImplementedError("Configuration not provided")

        super().__init__(*args, **kwargs)

    def _run_steps(self, steps: tuple[Step, ...]) -> None:
        """Submit each step to AiiDA and (in blocking mode) collect its results."""
        # Restore previously submitted steps, if a checkpoint exists.
        # NOTE: this overwrites step_data['configuration'], i.e. changes to
        # codes/resources are ignored while 'step_data.pkl' exists.
        try:
            with open('step_data.pkl', 'rb') as f:
                self.step_data = pickle.load(f)
        except (FileNotFoundError, EOFError, pickle.UnpicklingError):
            # Narrowed from a bare `except:`; first run simply has no checkpoint.
            pass

        self.from_scratch = False
        for step in steps:
            if step.directory in self.step_data['steps']:
                continue  # already submitted in a previous run
            if step.prefix in ('wannier90_preproc', 'pw2wannier90'):
                # Handled internally by the AiiDA Wannier90 workchain.
                print(f'skipping {step.prefix} step')
                continue
            self.from_scratch = True

            # ASE -> AiiDA conversion. TODO: clearer error if conversion fails.
            builder = get_builder_from_ase(calculator=step, step_data=self.step_data)
            running = submit(builder)
            self.step_data['steps'][step.directory] = {'workchain': running.pk}

        with open('step_data.pkl', 'wb') as f:
            pickle.dump(self.step_data, f)

        if not self.blocking and self.from_scratch:
            raise CalculationSubmitted("Calculation submitted to AiiDA, non blocking")
        elif self.blocking:
            # Poll every submitted workchain until it terminates.
            for step_entry in self.step_data['steps'].values():
                while not orm.load_node(step_entry['workchain']).is_finished:
                    time.sleep(5)

        for step in steps:
            # Convert AiiDA outputs back into the ASE calculator on the step.
            # TODO: move this conversion into a separate function.
            if step.prefix in ('wannier90_preproc', 'pw2wannier90'):
                continue
            workchain = orm.load_node(self.step_data['steps'][step.directory]['workchain'])
            if "remote_folder" in workchain.outputs:
                self.step_data['steps'][step.directory]['remote_folder'] = \
                    workchain.outputs.remote_folder.pk
            output = None
            if step.ext_out == ".wout":
                output = read_output_file(step, workchain.outputs.wannier90.retrieved)
            elif step.ext_out in (".pwo", ".kho"):
                output = read_output_file(step, workchain.outputs.retrieved)
                if hasattr(output.calc, 'kpts'):
                    step.kpts = output.calc.kpts
            else:
                output = read_output_file(step, workchain.outputs.retrieved)
            if step.ext_out in (".pwo", ".wout", ".kso", ".kho"):
                step.calc = output.calc
                step.results = output.calc.results
                step.generate_band_structure(
                    nelec=int(workchain.outputs.output_parameters.get_dict()['number_of_electrons']))

            self._step_completed_message(step)

        # If we reached here, all future steps should be performed from scratch.
        self.from_scratch = True

        # Dump again to persist the updated information (e.g. remote folders).
        with open('step_data.pkl', 'wb') as f:
            pickle.dump(self.step_data, f)

    def load_old_calculator(self, calc: Calc):
        raise NotImplementedError  # load_old_calculator(calc)

    def get_pseudo_data(self, workflow):
        """Fetch the workflow's pseudopotentials from the 'pseudo_group' AiiDA group.

        Returns a dict mapping element symbol -> parsed pseudo (read_pseudo_file).
        """
        pseudo_data = {}
        symbols_list = list(workflow.pseudopotentials)

        qb = orm.QueryBuilder()
        qb.append(orm.Group, filters={'label': {'==': 'pseudo_group'}}, tag='pseudo_group')
        qb.append(UpfData, filters={'attributes.element': {'in': symbols_list}},
                  with_group='pseudo_group')

        for (pseudo,) in qb.all():
            filename = pseudo.attributes.element + '.upf'
            with tempfile.TemporaryDirectory() as dirpath:
                # Parenthesised join: the original `Path(dirpath) / element + '.upf'`
                # evaluated `/` first and then raised TypeError (Path + str); it
                # also opened the repository file with 'wb' while reading from it.
                temp_file = pathlib.Path(dirpath) / filename
                with pseudo.open(filename, 'rb') as handle:
                    temp_file.write_bytes(handle.read())
                pseudo_data[pseudo.attributes.element] = read_pseudo_file(temp_file)

        return pseudo_data


def load_old_calculator(calc):
    """Reload a completed calculation from file.

    This is a separate module-level function so that it can be imported by
    other engines.
    """
    loaded_calc = calc.__class__.fromfile(calc.directory / calc.prefix)

    if loaded_calc.is_complete():
        # If it is complete, load the results and check convergence.
        calc.results = loaded_calc.results
        calc.check_convergence()

        # Load k-points if relevant.
        if hasattr(loaded_calc, 'kpts'):
            calc.kpts = loaded_calc.kpts

        # NOTE(review): ReturnsBandStructure, ProjwfcCalculator and PhCalculator
        # are not imported in this module -- presumably they come from
        # koopmans.calculators; TODO confirm and add the imports.
        if isinstance(calc, ReturnsBandStructure):
            calc.generate_band_structure()

        if isinstance(calc, ProjwfcCalculator):
            calc.generate_dos()

        if isinstance(calc, PhCalculator):
            calc.read_dynG()

    return loaded_calc
able to use these functions! @@ -272,15 +272,16 @@ def from_wann2kc_to_KcwCalculation(wann2kc_calculator): builder.metadata = wann2kc_calculator.parameters.mode["metadata_kcw"] builder.parent_folder = wann2kc_calculator.parent_folder - if hasattr(wann2kc_calculator, "w90_files"): - builder.wann_emp_u_mat = wann2kc_calculator.w90_files["emp"]["u_mat"] - builder.wann_emp_u_dis_mat = wann2kc_calculator.w90_files["emp"][ + if hasattr(wann2kc_calculator, "wannier90_files"): + builder.wann_emp_u_mat = wann2kc_calculator.wannier90_files["emp"]["u_mat"] + builder.wann_u_mat = wann2kc_calculator.wannier90_files["occ"]["u_mat"] + builder.wann_emp_u_dis_mat = wann2kc_calculator.wannier90_files["emp"][ "u_dis_mat" ] - builder.wann_centres_xyz = wann2kc_calculator.w90_files["occ"][ + builder.wann_centres_xyz = wann2kc_calculator.wannier90_files["occ"][ "centres_xyz" ] - builder.wann_emp_centres_xyz = wann2kc_calculator.w90_files["emp"][ + builder.wann_emp_centres_xyz = wann2kc_calculator.wannier90_files["emp"][ "centres_xyz" ] @@ -354,14 +355,14 @@ def from_kcwham_to_KcwCalculation(kcw_calculator): builder.metadata = kcw_calculator.parameters.mode["metadata_kcw"] builder.parent_folder = kcw_calculator.parent_folder - if hasattr(kcw_calculator, "w90_files") and control_dict.get( + if hasattr(kcw_calculator, "wannier90_files") and control_dict.get( "read_unitary_matrix", False ): - builder.wann_u_mat = kcw_calculator.w90_files["occ"]["u_mat"] - builder.wann_emp_u_mat = kcw_calculator.w90_files["emp"]["u_mat"] - builder.wann_emp_u_dis_mat = kcw_calculator.w90_files["emp"]["u_dis_mat"] - builder.wann_centres_xyz = kcw_calculator.w90_files["occ"]["centres_xyz"] - builder.wann_emp_centres_xyz = kcw_calculator.w90_files["emp"][ + builder.wann_u_mat = kcw_calculator.wannier90_files["occ"]["u_mat"] + builder.wann_emp_u_mat = kcw_calculator.wannier90_files["emp"]["u_mat"] + builder.wann_emp_u_dis_mat = kcw_calculator.wannier90_files["emp"]["u_dis_mat"] + builder.wann_centres_xyz = 
kcw_calculator.wannier90_files["occ"]["centres_xyz"] + builder.wann_emp_centres_xyz = kcw_calculator.wannier90_files["emp"][ "centres_xyz" ] @@ -438,14 +439,14 @@ def from_kcwscreen_to_KcwCalculation(kcw_calculator): builder.metadata = kcw_calculator.parameters.mode["metadata_kcw"] builder.parent_folder = kcw_calculator.parent_folder - if hasattr(kcw_calculator, "w90_files") and control_dict.get( + if hasattr(kcw_calculator, "wannier90_files") and control_dict.get( "read_unitary_matrix", False ): - builder.wann_u_mat = kcw_calculator.w90_files["occ"]["u_mat"] - builder.wann_emp_u_mat = kcw_calculator.w90_files["emp"]["u_mat"] - builder.wann_emp_u_dis_mat = kcw_calculator.w90_files["emp"]["u_dis_mat"] - builder.wann_centres_xyz = kcw_calculator.w90_files["occ"]["centres_xyz"] - builder.wann_emp_centres_xyz = kcw_calculator.w90_files["emp"][ + builder.wann_u_mat = kcw_calculator.wannier90_files["occ"]["u_mat"] + builder.wann_emp_u_mat = kcw_calculator.wannier90_files["emp"]["u_mat"] + builder.wann_emp_u_dis_mat = kcw_calculator.wannier90_files["emp"]["u_dis_mat"] + builder.wann_centres_xyz = kcw_calculator.wannier90_files["occ"]["centres_xyz"] + builder.wann_emp_centres_xyz = kcw_calculator.wannier90_files["emp"][ "centres_xyz" ] @@ -661,12 +662,21 @@ def wrapper_aiida_trigger(self): if self.parameters.mode == "ase": return _fetch_linked_files(self) else: # if pseudo linking, src_calc = None + self.wannier90_files = {} for dest_filename, (src_calc, src_filename, symlink, recursive_symlink, overwrite) in self.linked_files.items(): # check the we have only one src_calc!!! if hasattr(src_calc,"wchain"): - self.parent_folder = src_calc.wchain.outputs.remote_folder - elif hasattr(src_calc,"dst_file"): - pass #linik w90_files. 
def aiida_merge_wannier_files_trigger(merge_wannier_files):
    # Needed to populate the self.wannier90_files dictionary in the
    # wannierize workflow after merging the per-block Wannier90 outputs.
    @functools.wraps(merge_wannier_files)
    def wrapper_aiida_trigger(self, block, merge_directory, prefix):
        merge_wannier_files(self, block, merge_directory, prefix)
        if self.parameters.mode == "ase":
            return
        # AiiDA mode: expose the merged single-file data under the block label
        # and drop the references held by the merge processes.
        self.wannier90_files[merge_directory] = {
            'hr_dat': self.merge_hr_proc.singlefiledata,
        }
        del self.merge_hr_proc.singlefiledata
        if self.parameters.method == 'dfpt':
            self.wannier90_files[merge_directory].update({
                "u_mat": self.merge_u_proc.singlefiledata,
                "centres_xyz": self.merge_centers_proc.singlefiledata,
            })
            del self.merge_u_proc.singlefiledata
            del self.merge_centers_proc.singlefiledata
    return wrapper_aiida_trigger


def aiida_trigger_run_calculator(run_calculator):
    # Wraps the run_calculator method: skips the steps that AiiDA's Wannier90
    # workchain performs internally and links the nscf outdir before a
    # wannier90 step.
    @functools.wraps(run_calculator)
    def wrapper_aiida_trigger(self, calc):
        if calc.prefix in ("wannier90_preproc", "pw2wannier90"):
            return
        if calc.prefix == "wannier90":
            from koopmans import calculators
            # Most recent nscf PW calculation provides the parent outdir.
            calc_nscf = [c for c in self.calculations
                         if isinstance(c, calculators.PWCalculator)
                         and c.parameters.calculation == 'nscf'][-1]
            self.link(calc_nscf, calc_nscf.parameters.outdir, calc, "")
        # The original duplicated this call in two identical branches.
        return run_calculator(self, calc)
    return wrapper_aiida_trigger
- - :returns: an exit code, if parsing fails (or nothing if parsing succeeds) - """ - output_filename = self.node.get_option("output_filename") - - # Check that folder content is as expected - files_retrieved = self.retrieved.list_object_names() - files_expected = [output_filename] - # Note: set(A) <= set(B) checks whether A is a subset of B - if not set(files_expected) <= set(files_retrieved): - self.logger.error( - f"Found files '{files_retrieved}', expected to find '{files_expected}'" - ) - return self.exit_codes.ERROR_MISSING_OUTPUT_FILES - - # add output file - self.logger.info(f"Parsing '{output_filename}'") - with self.retrieved.open(output_filename, "rb") as handle: - output_node = SinglefileData(file=handle) - self.out("koopmans", output_node) - - return ExitCode(0) diff --git a/src/aiida_koopmans/utils.py b/src/aiida_koopmans/utils.py new file mode 100644 index 0000000..c259b2d --- /dev/null +++ b/src/aiida_koopmans/utils.py @@ -0,0 +1,276 @@ + +import shutil +import pathlib +import tempfile + +import numpy as np +import functools + +from aiida.common.exceptions import NotExistent +from aiida.orm import Code, Computer +from aiida_quantumespresso.calculations.pw import PwCalculation +from aiida_wannier90.calculations.wannier90 import Wannier90Calculation +from ase import io +from ase.io.espresso import kch_keys, kcp_keys, kcs_keys, pw_keys, w2kcw_keys + +from aiida_koopmans.calculations.kcw import KcwCalculation +from aiida_koopmans.data.utils import generate_singlefiledata, generate_alpha_singlefiledata, produce_wannier90_files + +LOCALHOST_NAME = "localhost-test" +KCW_BLOCKED_KEYWORDS = [t[1] for t in KcwCalculation._blocked_keywords] +PW_BLOCKED_KEYWORDS = [t[1] for t in PwCalculation._blocked_keywords] +WANNIER90_BLOCKED_KEYWORDS = [t[1] for t in Wannier90Calculation._BLOCKED_PARAMETER_KEYS] +ALL_BLOCKED_KEYWORDS = KCW_BLOCKED_KEYWORDS + PW_BLOCKED_KEYWORDS + WANNIER90_BLOCKED_KEYWORDS + [f'celldm({i})' for i in range (1,7)] + +def 
def get_builder_from_ase(calculator, step_data=None):
    """Dispatch an ASE calculator to the AiiDA builder factory matching its
    output-file extension (see ``mapping_calculators``)."""
    return mapping_calculators[calculator.ext_out](calculator, step_data)


# Pw calculator.
def get_PwBaseWorkChain_from_ase(pw_calculator, step_data=None):
    """Build a PwBaseWorkChain from an ASE pw calculator.

    ``step_data['configuration']`` must carry the AiiDA codes/resources, e.g.::

        {
            "pw_code": "pw-7.2-ok@localhost",
            "metadata": {
                "options": {
                    "max_wallclock_seconds": 3600,
                    "resources": {
                        "num_machines": 1,
                        "num_mpiprocs_per_machine": 1,
                        "num_cores_per_mpiproc": 1,
                    },
                    "custom_scheduler_commands": "export OMP_NUM_THREADS=1",
                }
            }
        }

    TODO: validate accepted keywords automatically against PwCalculation.
    """
    from aiida import load_profile, orm
    from aiida_quantumespresso.common.types import ElectronicType
    from aiida_quantumespresso.workflows.pw.base import PwBaseWorkChain

    load_profile()

    aiida_inputs = step_data['configuration']
    calc_params = pw_calculator._parameters
    # TODO: a single StructureData stored in step_data; beware workflows whose
    # output structure differs from the input one.
    structure = orm.StructureData(ase=pw_calculator.atoms)

    pw_overrides = {
        "CONTROL": {},
        "SYSTEM": {"nosym": True, "noinv": True},
        "ELECTRONS": {},
    }

    for k in pw_keys['control']:
        if k in calc_params and k not in ALL_BLOCKED_KEYWORDS:
            pw_overrides["CONTROL"][k] = calc_params[k]

    for k in pw_keys['system']:
        # BUGFIX: the original tested `k not in [ALL_BLOCKED_KEYWORDS,
        # 'tot_magnetization']`, i.e. membership in a two-element list, which
        # never matched and let every blocked keyword through.
        if k in calc_params and k not in ALL_BLOCKED_KEYWORDS and k != 'tot_magnetization':
            pw_overrides["SYSTEM"][k] = calc_params[k]

    for k in pw_keys['electrons']:
        if k in calc_params and k not in ALL_BLOCKED_KEYWORDS:
            pw_overrides["ELECTRONS"][k] = calc_params[k]

    builder = PwBaseWorkChain.get_builder_from_protocol(
        code=aiida_inputs["pw_code"],
        structure=structure,
        overrides={
            # TODO: store pseudos automatically from the koopmans folder, if absent.
            "pseudo_family": "PseudoDojo/0.4/LDA/SR/standard/upf",
            "pw": {"parameters": pw_overrides},
        },
        electronic_type=ElectronicType.INSULATOR,
    )
    builder.pw.metadata = aiida_inputs["metadata"]

    builder.kpoints = orm.KpointsData()
    if pw_overrides["CONTROL"]["calculation"] in ["scf", "nscf"]:
        builder.kpoints.set_kpoints_mesh(calc_params["kpts"])
    elif pw_overrides["CONTROL"]["calculation"] == "bands":
        # A bands run needs explicit kpoints. TODO: check cartesian=False is correct.
        builder.kpoints.set_kpoints(calc_params["kpts"].kpts, cartesian=False)

    # Link the (single) parent calculation's remote folder, if already available.
    parent_calculators = [f[0].directory for f in pw_calculator.linked_files.values()
                          if f[0] is not None]
    if len(set(parent_calculators)) > 1:
        raise ValueError("More than one parent calculator found.")
    elif len(set(parent_calculators)) == 1:
        if "remote_folder" in step_data['steps'][parent_calculators[0]]:
            builder.pw.parent_folder = orm.load_node(
                step_data['steps'][parent_calculators[0]]["remote_folder"])

    return builder


def get_Wannier90BandsWorkChain_builder_from_ase(w90_calculator, step_data=None):
    """Build a Wannier90BandsWorkChain from an initialized Wannier90Calculator.

    The builder is obtained as in WannierizeWorkflow, so each Wannierization
    block gets everything it needs.
    """
    from aiida import load_profile, orm
    from aiida_wannier90_workflows.common.types import WannierProjectionType
    from aiida_wannier90_workflows.workflows import Wannier90BandsWorkChain
    # NOTE(review): the original also imported kpoints/builder setter helpers
    # and submit_and_add_group from aiida_wannier90_workflows but never used
    # them; dropped here.

    load_profile()

    # Locate the parent nscf PwBaseWorkChain among the already-run steps.
    nscf = None
    for step_dir, val in step_data['steps'].items():
        if "nscf" in str(step_dir):
            nscf = orm.load_node(val["workchain"])
    if not nscf:
        raise ValueError("No nscf step found.")

    aiida_inputs = step_data['configuration']

    codes = {
        "pw": aiida_inputs["pw_code"],
        "pw2wannier90": aiida_inputs["pw2wannier90_code"],
        # "projwfc": aiida_inputs["projwfc_code"],
        "wannier90": aiida_inputs["wannier90_code"],
    }

    builder = Wannier90BandsWorkChain.get_builder_from_protocol(
        codes=codes,
        structure=nscf.inputs.pw.structure,
        pseudo_family="PseudoDojo/0.4/LDA/SR/standard/upf",
        protocol="fast",
        projection_type=WannierProjectionType.ANALYTIC,
        print_summary=False,
    )

    # Reuse the nscf run's explicit k-points.
    kpoints = orm.KpointsData()
    kpoints.set_cell_from_structure(builder.structure)
    kpoints.set_kpoints(nscf.outputs.output_band.get_array('kpoints'), cartesian=False)
    builder.wannier90.wannier90.kpoints = kpoints

    # Build the k-path (labels + coordinates) from the calculator's kpoint_path.
    k_coords = []
    k_labels = []
    k_path = w90_calculator.parameters.kpoint_path.kpts
    special_k = w90_calculator.parameters.kpoint_path.todict()["special_points"]
    k_linear, special_k_coords, special_k_labels = \
        w90_calculator.parameters.kpoint_path.get_linear_kpoint_axis()
    for coords, label in zip(special_k_coords, special_k_labels):
        # np.isclose instead of the original exact float equality (==), which
        # is fragile for values coming through floating-point arithmetic.
        idx = np.where(np.isclose(k_linear, coords))[0]
        k_labels.append([idx[0], label])
        k_coords.append(special_k[label].tolist())

    kpoints_path = orm.KpointsData()
    kpoints_path.set_kpoints(k_path, labels=k_labels, cartesian=False)
    builder.kpoint_path = kpoints_path

    # Parameters and projections from the Wannier90Calculator data.
    params = builder.wannier90.wannier90.parameters.get_dict()

    del builder.scf
    del builder.nscf
    del builder.projwfc

    for k, v in w90_calculator.parameters.items():
        if k not in ["kpoints", "kpoint_path", "projections"]:
            params[k] = v

    # Projections in wannier90 format. For now only the following conversion
    # is supported:
    #   {'fsite': [0.0, 0.0, 0.0], 'ang_mtm': 'sp3'}  ->  "f=0.0,0.0,0.0:sp3"
    converted_projs = []
    for proj in w90_calculator.todict()['_parameters']["projections"]:
        if "fsite" in proj:
            position = "f=" + str(proj["fsite"]).replace("[", "").replace("]", "").replace(" ", "")
        elif "site" in proj:
            position = str(proj["site"])
        orbital = proj["ang_mtm"]
        converted_projs.append(position + ":" + orbital)

    builder.wannier90.wannier90.projections = orm.List(list=converted_projs)
    # Remove this pop to fall back to automatic atomic projections.
    params.pop('auto_projections', None)

    # The fermi energy is required for the workchain to run.
    nscf_params = nscf.outputs.output_parameters.get_dict()
    try:
        fermi_energy = nscf_params["fermi_energy_up"]  # spin-polarised output
    except KeyError:  # narrowed from a bare `except:`
        fermi_energy = nscf_params["fermi_energy"]
    params["fermi_energy"] = fermi_energy

    builder.wannier90.wannier90.parameters = orm.Dict(dict=params)

    # Resources.
    builder.pw2wannier90.pw2wannier90.metadata = aiida_inputs["metadata"]

    default_w90_metadata = {
        "options": {
            "max_wallclock_seconds": 3600,
            "resources": {
                "num_machines": 1,
                "num_mpiprocs_per_machine": 1,
                "num_cores_per_mpiproc": 1,
            },
            "custom_scheduler_commands": "export OMP_NUM_THREADS=1",
        }
    }
    builder.wannier90.wannier90.metadata = aiida_inputs.get('metadata_w90', default_w90_metadata)

    builder.pw2wannier90.pw2wannier90.parent_folder = nscf.outputs.remote_folder

    # get_fermi_energy_from_nscf + get_homo_lumo do not work for fixed
    # occupations; TODO: add the relevant parsing in
    # aiida-wannier90-workflows/utils/workflows/pw.py.
    builder.wannier90.shift_energy_windows = False

    # pw2wannier90 parameters, required here. TODO: move to overrides.
    params_pw2wannier90 = builder.pw2wannier90.pw2wannier90.parameters.get_dict()
    params_pw2wannier90['inputpp']["wan_mode"] = "standalone"
    if nscf.inputs.pw.parameters.get_dict()["SYSTEM"]["nspin"] > 1:
        params_pw2wannier90['inputpp']["spin_component"] = "up"
    builder.pw2wannier90.pw2wannier90.parameters = orm.Dict(dict=params_pw2wannier90)

    return builder
## Mapping from ASE output-file extension to the AiiDA builder factory,
## used in `aiida_calculate_trigger` via get_builder_from_ase.
mapping_calculators = {
    ".pwo": get_PwBaseWorkChain_from_ase,
    ".wout": get_Wannier90BandsWorkChain_builder_from_ase,
    # ".w2ko": from_wann2kc_to_KcwCalculation,
    # ".kso": from_kcwscreen_to_KcwCalculation,
    # ".kho": from_kcwham_to_KcwCalculation,
}


def read_output_file(calculator, retrieved, inner_remote_folder=None):
    """Read a calculation output file with ASE ``io.read`` from AiiDA outputs.

    Mimics the ``read_results`` method of ase-koopmans
    (ase/calculators/espresso/_espresso.py).  The `.out`/`.wout` file in the
    ``retrieved`` FolderData is copied to a temporary file named after the
    calculator (``<label><ext_out>``) so ASE picks the right parser.

    :param calculator: ASE calculator; must carry the related AiiDA workchain.
    :param retrieved: AiiDA retrieved FolderData of the workchain.
    :param inner_remote_folder: unused; kept for interface compatibility.
    :raises FileNotFoundError: if no `.out`/`.wout` file was retrieved
        (the original silently fell through to an unbound local -> NameError).
    """
    with tempfile.TemporaryDirectory() as dirpath:
        temp_file = None
        for filename in retrieved.base.repository.list_object_names():
            if '.out' in filename or '.wout' in filename:
                # Recreate the file under the name ASE expects.
                readable_filename = calculator.label.split("/")[-1] + calculator.ext_out
                temp_file = pathlib.Path(dirpath) / readable_filename
                with retrieved.open(filename, 'rb') as handle:
                    temp_file.write_bytes(handle.read())
        if temp_file is None:
            raise FileNotFoundError(
                f"No '.out'/'.wout' file among the retrieved outputs for {calculator.label}")
        return io.read(temp_file)