diff --git a/src/aiida_koopmans/engine/aiida.py b/src/aiida_koopmans/engine/aiida.py index e2e89c7..60c2fdf 100644 --- a/src/aiida_koopmans/engine/aiida.py +++ b/src/aiida_koopmans/engine/aiida.py @@ -1,6 +1,6 @@ from koopmans.engines.engine import Engine from koopmans.step import Step -from koopmans.calculators import Calc +from koopmans.calculators import Calc, ProjwfcCalculator from koopmans.pseudopotentials import read_pseudo_file from koopmans.status import Status from koopmans.files import FilePointer @@ -17,6 +17,9 @@ import time import dill as pickle +import pathlib +import tempfile +import fnmatch from aiida import orm, load_profile load_profile() @@ -60,7 +63,7 @@ def run(self, step: Step): if step.prefix in ['wannier90_preproc', 'pw2wannier90']: self.set_status(step, Status.COMPLETED) return - + self.step_data['steps'][step.uid] = {} # maybe not needed builder, self.step_data = get_builder_from_ase(calculator=step, step_data=self.step_data) # ASE to AiiDA conversion. put some error message if the conversion fails @@ -139,6 +142,8 @@ def load_results(self, step: Step) -> None: self.load_step_data() if isinstance(step, Process): + step.load_outputs() + self._step_completed_message(step) return if step.prefix in ['wannier90_preproc', 'pw2wannier90']: @@ -156,13 +161,15 @@ def load_results(self, step: Step) -> None: step.kpts = output.calc.kpts else: output = read_output_file(step, workchain.outputs.retrieved) - if step.ext_out in [".pwo",".wout",".kso",".kho"]: + + + if step.ext_out in [".pwo",".pro",".wout",".kso",".kho"]: step.calc = output.calc step.results = output.calc.results - if step.ext_out == ".pwo": step.generate_band_structure() #nelec=int(workchain.outputs.output_parameters.get_dict()['number_of_electrons'])) + #if step.ext_out == ".pwo": step.generate_band_structure() #nelec=int(workchain.outputs.output_parameters.get_dict()['number_of_electrons'])) - self._step_completed_message(step) + ''' if step.ext_out in [".pro"]: pdos_dir = dump_pdos_outputs(step, workchain.outputs.retrieved) @@ -179,9 +186,12 @@ def load_results(self, step: Step) -> None: delete_directory(pdos_dir.parent) step.directory = prev_dir - self._step_completed_message(step) - + self._step_completed_message(step + ''' + step._post_run() self.dump_step_data() + self._step_completed_message(step) + def load_old_calculator(self, calc: Calc): raise NotImplementedError # load_old_calculator(calc) @@ -209,6 +219,9 @@ def get_pseudopotential(self, library: str, element: str): return pseudo_data def read(self, file: FilePointer, binary=False) -> str | bytes: + if isinstance(file[0], Process): + singlefiledata = orm.load_node(self.step_data['steps'][file[0].uid][str(file.name)]) + return singlefiledata.get_content(mode='rb') workchain = orm.load_node(self.step_data['steps'][file[0].uid]['workchain']) filename = str(file[1]).replace(file[0].prefix, 'aiida') if 'wannier90' in file[0].prefix: @@ -225,19 +238,31 @@ def read(self, file: FilePointer, binary=False) -> str | bytes: def write(self, content: str | bytes, file: FilePointer) -> None: if 'inputs.pkl' in str(file[1]): return - if isinstance(file[0], Process): - filename = file[0].inputs.dst_file - else: - filename = str(file[1]).replace(file[0].prefix, 'aiida') + + filename = str(file.name) if isinstance(content, bytes): - # skip the dumping of the *out.pkl file, we don't want as SinglefileData - return - singlefile = orm.SinglefileData.from_string(content, filename) + singlefile = orm.SinglefileData.from_bytes(content, filename) + else: + singlefile = orm.SinglefileData.from_string(content, filename) singlefile.store() - self.step_data['steps'][file[0].uid][str(filename)] = singlefile.pk + self.step_data['steps'][file[0].uid][filename] = singlefile.pk return singlefile - def glob(self, pattern: FilePointer, recursive=False) -> Generator[FilePointer, None, None]: - raise NotImplementedError() + def glob(self, directory: FilePointer, pattern: str, recursive: bool = False) -> Generator[FilePointer, None, None]: + + workchain = orm.load_node(self.step_data['steps'][directory.parent.uid]['workchain']) + if 'wannier90' in getattr(directory.parent, 'prefix', ''): + listnames = workchain.outputs.wannier90.retrieved.base.repository.list_object_names() + else: + listnames = workchain.outputs.retrieved.base.repository.list_object_names() + for name in listnames: + tomatch = str(directory.name / pattern) + if hasattr(directory.parent, 'prefix'): + tomatch = tomatch.replace(directory.parent.prefix, 'aiida') + if isinstance(directory.parent, ProjwfcCalculator): # TODO: this is a hack, we need to find a better way to do this. + tomatch = tomatch.replace(directory.parent.parameters.filpdos, 'aiida') + if fnmatch.fnmatch(name, tomatch): + yield FilePointer(directory.parent, pathlib.Path(name)) + \ No newline at end of file diff --git a/src/aiida_koopmans/utils.py b/src/aiida_koopmans/utils.py index 007af83..7406fa3 100644 --- a/src/aiida_koopmans/utils.py +++ b/src/aiida_koopmans/utils.py @@ -42,8 +42,8 @@ def get_PwBaseWorkChain_from_ase(pw_calculator, step_data=None): structure = None parent_folder = None - for step, val in step_data['steps'].items(): - if "scf" in str(step) and ("nscf" in pw_calculator.uid or "bands" in pw_calculator.uid): + for step_uid, val in step_data['steps'].items(): + if "scf" in step_uid and ("nscf" in pw_calculator.uid or "bands" in pw_calculator.uid): scf = orm.load_node(val["workchain"]) structure = scf.inputs.pw.structure parent_folder = scf.outputs.remote_folder