Skip to content

Commit

Permalink
engine.glob method and more.
Browse files Browse the repository at this point in the history
  • Loading branch information
mikibonacci committed Dec 6, 2024
1 parent 5272b38 commit cef1911
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 19 deletions.
59 changes: 42 additions & 17 deletions src/aiida_koopmans/engine/aiida.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from koopmans.engines.engine import Engine
from koopmans.step import Step
from koopmans.calculators import Calc
from koopmans.calculators import Calc, ProjwfcCalculator
from koopmans.pseudopotentials import read_pseudo_file
from koopmans.status import Status
from koopmans.files import FilePointer
Expand All @@ -17,6 +17,9 @@
import time

import dill as pickle
import pathlib
import tempfile
import fnmatch

from aiida import orm, load_profile
load_profile()
Expand Down Expand Up @@ -60,7 +63,7 @@ def run(self, step: Step):
if step.prefix in ['wannier90_preproc', 'pw2wannier90']:
self.set_status(step, Status.COMPLETED)
return

self.step_data['steps'][step.uid] = {} # maybe not needed

builder, self.step_data = get_builder_from_ase(calculator=step, step_data=self.step_data) # ASE to AiiDA conversion. put some error message if the conversion fails
Expand Down Expand Up @@ -139,6 +142,8 @@ def load_results(self, step: Step) -> None:
self.load_step_data()

if isinstance(step, Process):
step.load_outputs()
self._step_completed_message(step)
return

if step.prefix in ['wannier90_preproc', 'pw2wannier90']:
Expand All @@ -156,13 +161,15 @@ def load_results(self, step: Step) -> None:
step.kpts = output.calc.kpts
else:
output = read_output_file(step, workchain.outputs.retrieved)
if step.ext_out in [".pwo",".wout",".kso",".kho"]:


if step.ext_out in [".pwo",".pro",".wout",".kso",".kho"]:
step.calc = output.calc
step.results = output.calc.results
if step.ext_out == ".pwo": step.generate_band_structure() #nelec=int(workchain.outputs.output_parameters.get_dict()['number_of_electrons']))
#if step.ext_out == ".pwo": step.generate_band_structure() #nelec=int(workchain.outputs.output_parameters.get_dict()['number_of_electrons']))

self._step_completed_message(step)

'''
if step.ext_out in [".pro"]:
pdos_dir = dump_pdos_outputs(step, workchain.outputs.retrieved)
Expand All @@ -179,9 +186,12 @@ def load_results(self, step: Step) -> None:
delete_directory(pdos_dir.parent)
step.directory = prev_dir
self._step_completed_message(step)

self._step_completed_message(step
'''
step._post_run()
self.dump_step_data()
self._step_completed_message(step)


def load_old_calculator(self, calc: Calc):
raise NotImplementedError # load_old_calculator(calc)
Expand Down Expand Up @@ -209,6 +219,9 @@ def get_pseudopotential(self, library: str, element: str):
return pseudo_data

def read(self, file: FilePointer, binary=False) -> str | bytes:
if isinstance(file[0], Process):
singlefiledata = orm.load_node(self.step_data['steps'][file[0].uid][str(file.name)])
return singlefiledata.get_content(mode='rb')
workchain = orm.load_node(self.step_data['steps'][file[0].uid]['workchain'])
filename = str(file[1]).replace(file[0].prefix, 'aiida')
if 'wannier90' in file[0].prefix:
Expand All @@ -225,19 +238,31 @@ def read(self, file: FilePointer, binary=False) -> str | bytes:
def write(self, content: str | bytes, file: FilePointer) -> None:
if 'inputs.pkl' in str(file[1]):
return
if isinstance(file[0], Process):
filename = file[0].inputs.dst_file
else:
filename = str(file[1]).replace(file[0].prefix, 'aiida')

filename = str(file.name)

if isinstance(content, bytes):
# skip the dumping of the *out.pkl file, we don't want as SinglefileData
return
singlefile = orm.SinglefileData.from_string(content, filename)
singlefile = orm.SinglefileData.from_bytes(content, filename)
else:
singlefile = orm.SinglefileData.from_string(content, filename)
singlefile.store()
self.step_data['steps'][file[0].uid][str(filename)] = singlefile.pk
self.step_data['steps'][file[0].uid][filename] = singlefile.pk
return singlefile

def glob(self, pattern: FilePointer, recursive=False) -> Generator[FilePointer, None, None]:
raise NotImplementedError()
def glob(self, directory: FilePointer, pattern: str, recursive: bool = False) -> Generator[FilePointer, None, None]:

workchain = orm.load_node(self.step_data['steps'][directory.parent.uid]['workchain'])
if 'wannier90' in getattr(directory.parent, 'prefix', ''):
listnames = workchain.outputs.wannier90.retrieved.base.repository.list_object_names()
else:
listnames = workchain.outputs.retrieved.base.repository.list_object_names()

for name in listnames:
tomatch = str(directory.name / pattern)
if hasattr(directory.parent, 'prefix'):
tomatch = tomatch.replace(directory.parent.prefix, 'aiida')
if isinstance(directory.parent, ProjwfcCalculator): # TODO: this is a hack, we need to find a better way to do this.
tomatch = tomatch.replace(directory.parent.parameters.filpdos, 'aiida')
if fnmatch.fnmatch(name, tomatch):
yield FilePointer(directory.parent, pathlib.Path(name))

4 changes: 2 additions & 2 deletions src/aiida_koopmans/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def get_PwBaseWorkChain_from_ase(pw_calculator, step_data=None):

structure = None
parent_folder = None
for step, val in step_data['steps'].items():
if "scf" in str(step) and ("nscf" in pw_calculator.uid or "bands" in pw_calculator.uid):
for step_uid, val in step_data['steps'].items():
if "scf" in step_uid and ("nscf" in pw_calculator.uid or "bands" in pw_calculator.uid):
scf = orm.load_node(val["workchain"])
structure = scf.inputs.pw.structure
parent_folder = scf.outputs.remote_folder
Expand Down

0 comments on commit cef1911

Please sign in to comment.