diff --git a/notebooks/h20_wtree.ipynb b/notebooks/h20_wtree.ipynb
index 59a9dc7..7e63faf 100644
--- a/notebooks/h20_wtree.ipynb
+++ b/notebooks/h20_wtree.ipynb
@@ -199,7 +199,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from ase import Atoms\n",
+    "from ase_koopmans import Atoms\n",
     "import copy\n",
     "\n",
     "H2O = Atoms(atoms,\n",
diff --git a/notebooks/koop_insp.ipynb b/notebooks/koop_insp.ipynb
index 82195ea..328ae7d 100644
--- a/notebooks/koop_insp.ipynb
+++ b/notebooks/koop_insp.ipynb
@@ -556,7 +556,7 @@
    "source": [
     "## Trying aiida-shell\n",
     "\n",
-    "1. from ase calculator to aiida-shell inputs\n",
+    "1. from ase_koopmans calculator to aiida-shell inputs\n",
     "\n",
     "The first thing to do is to write the input file."
    ]
diff --git a/notebooks/multiple_pw_conv.ipynb b/notebooks/multiple_pw_conv.ipynb
index 8450212..b848859 100644
--- a/notebooks/multiple_pw_conv.ipynb
+++ b/notebooks/multiple_pw_conv.ipynb
@@ -169,7 +169,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from ase import Atoms\n",
+    "from ase_koopmans import Atoms\n",
     "import copy\n",
     "\n",
     "H2O = Atoms(atoms,\n",
diff --git a/quick_start.ipynb b/quick_start.ipynb
index 26ea5e4..c59097f 100644
--- a/quick_start.ipynb
+++ b/quick_start.ipynb
@@ -201,7 +201,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from ase.io import read\n",
+    "from ase_koopmans.io import read\n",
     "\n",
     "ozone = read(\"/home/jovyan/work/koopmans_calcs/tutorial_1/ozon.xsf\")\n",
     "ozone.cell = [[14.1738, 0.0, 0.0],\n",
diff --git a/src/aiida_koopmans/engine/aiida.py b/src/aiida_koopmans/engine/aiida.py
index 52adf6f..613a989 100644
--- a/src/aiida_koopmans/engine/aiida.py
+++ b/src/aiida_koopmans/engine/aiida.py
@@ -2,6 +2,7 @@
 from koopmans.step import Step
 from koopmans.calculators import Calc
 from koopmans.pseudopotentials import read_pseudo_file
+from koopmans.status import Status
 
 from aiida.engine import run_get_node, submit
 
@@ -23,7 +24,6 @@ class AiiDAEngine(Engine):
     step_data = {calc.directory: {'workchain': workchain, 'remote_folder': remote_folder}} and any other info we need for AiiDA.
     """
 
-
     def __init__(self, *args, **kwargs):
         self.blocking = kwargs.pop('blocking', True)
         self.step_data = { # TODO: change to a better name
@@ -43,77 +43,97 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+
+    def run(self, step: Step):
+
+        self.get_status(step)
+        if step.prefix in ['wannier90_preproc', 'pw2wannier90']:
+            self.set_status(step, Status.COMPLETED)
+            #self._step_completed_message(step)
+            return
+
+        self.step_data['steps'][step.uid] = {} # maybe not needed
+        builder = get_builder_from_ase(calculator=step, step_data=self.step_data) # ASE to AiiDA conversion. put some error message if the conversion fails
+        running = submit(builder)
+        # running = aiidawrapperwchain.submit(builder) # in the non-blocking case.
+        self.step_data['steps'][step.uid] = {'workchain': running.pk, } #'remote_folder': running.outputs.remote_folder}
+
+        self.set_status(step, Status.RUNNING)
+
+        return
-
-    def _run_steps(self, steps: tuple[Step, ...]) -> None:
+    def load_step_data(self):
         try:
             with open('step_data.pkl', 'rb') as f:
-                self.step_data = pickle.load(f) # this will overwrite the step_data[configuration], ie. if we change codes or res we will not see it if the file already exists.
-        except:
+                # this will overwrite the step_data[configuration],
+                # i.e. if we change codes or res we will not see it if
+                # the file already exists.
+                self.step_data = pickle.load(f)
+        except FileNotFoundError:
             pass
-        self.from_scratch = False
-        for step in steps:
-            # self._step_running_message(step)
-            if step.directory in self.step_data['steps']:
-                continue
-            elif step.prefix in ['wannier90_preproc', 'pw2wannier90']:
-                print(f'skipping {step.prefix} step')
-                continue
-            else:
-                self.from_scratch = True
-
-            #step.run()
-            self.step_data['steps'][step.directory] = {}
-            builder = get_builder_from_ase(calculator=step, step_data=self.step_data) # ASE to AiiDA conversion. put some error message if the conversion fails
-            running = submit(builder)
-            # running = aiidawrapperwchain.submit(builder) # in the non-blocking case.
-            self.step_data['steps'][step.directory] = {'workchain': running.pk, } #'remote_folder': running.outputs.remote_folder}
-
-        #if self.from_scratch:
+    def dump_step_data(self):
         with open('step_data.pkl', 'wb') as f:
-            pickle.dump(self.step_data, f)
+            pickle.dump(self.step_data, f)
+
+    def get_status(self, step: Step) -> Status:
+        return self.get_status_by_uid(step.uid)
-        if not self.blocking and self.from_scratch:
-            raise CalculationSubmitted("Calculation submitted to AiiDA, non blocking")
-        elif self.blocking:
-            for step in self.step_data['steps'].values():
-                while not orm.load_node(step['workchain']).is_finished:
-                    time.sleep(5)
-
-        for step in steps:
+
+    def get_status_by_uid(self, uid: str) -> Status:
+        self.load_step_data()
+        if uid not in self.step_data['steps']:
+            self.step_data['steps'][uid] = {'status': Status.NOT_STARTED}
+        return self.step_data['steps'][uid]['status']
+
+    def set_status(self, step: Step, status: Status):
+        self.set_status_by_uid(step.uid, status)
+
+    def set_status_by_uid(self, uid: str, status: Status):
+        self.step_data['steps'][uid]['status'] = status
+        self.dump_step_data()
+
+    def update_statuses(self) -> None:
+        time.sleep(5)
+        for uid in self.step_data['steps']:
             # convert from AiiDA to ASE results and populate ASE calculator
-            # TOBE put in a separate function
-            if step.prefix in ['wannier90_preproc', 'pw2wannier90']:
+            if not self.get_status_by_uid(uid) == Status.RUNNING:
                 continue
-            workchain = orm.load_node(self.step_data['steps'][step.directory]['workchain'])
-            if "remote_folder" in workchain.outputs:
-                self.step_data['steps'][step.directory]['remote_folder'] = workchain.outputs.remote_folder.pk
-            output = None
-            if step.ext_out == ".wout":
-                output = read_output_file(step, workchain.outputs.wannier90.retrieved)
-            elif step.ext_out in [".pwo",".kho"]:
-                output = read_output_file(step, workchain.outputs.retrieved)
-                if hasattr(output.calc, 'kpts'):
-                    step.kpts = output.calc.kpts
-            else:
-                output = read_output_file(step, workchain.outputs.retrieved)
-            if step.ext_out in [".pwo",".wout",".kso",".kho"]:
-                step.calc = output.calc
-                step.results = output.calc.results
-                step.generate_band_structure(nelec=int(workchain.outputs.output_parameters.get_dict()['number_of_electrons']))
+
+            workchain = orm.load_node(self.step_data['steps'][uid]['workchain'])
+            if workchain.is_finished_ok:
+                self.set_status_by_uid(uid, Status.COMPLETED)
+            elif workchain.is_finished or workchain.is_excepted or workchain.is_killed:
+                self.set_status_by_uid(uid, Status.FAILED)
+
+        return
+
+    def load_results(self, step: Step) -> None:
+
+        self.load_step_data()
+        workchain = orm.load_node(self.step_data['steps'][step.uid]['workchain'])
+        if "remote_folder" in workchain.outputs:
+            self.step_data['steps'][step.uid]['remote_folder'] = workchain.outputs.remote_folder.pk
+        output = None
+        if step.ext_out == ".wout":
+            output = read_output_file(step, workchain.outputs.wannier90.retrieved)
+        elif step.ext_out in [".pwo",".kho"]:
+            output = read_output_file(step, workchain.outputs.retrieved)
+            if hasattr(output.calc, 'kpts'):
+                step.kpts = output.calc.kpts
+        else:
+            output = read_output_file(step, workchain.outputs.retrieved)
+        if step.ext_out in [".pwo",".wout",".kso",".kho"]:
+            step.calc = output.calc
+            step.results = output.calc.results
+            step.generate_band_structure() #nelec=int(workchain.outputs.output_parameters.get_dict()['number_of_electrons']))
+        self._step_completed_message(step)
-
-        # If we reached here, all future steps should be performed from scratch
-        self.from_scratch = True
-
-        # dump again to have update the information
-        with open('step_data.pkl', 'wb') as f:
-            pickle.dump(self.step_data, f)
+
+        self.dump_step_data()
-        return
-
+
     def load_old_calculator(self, calc: Calc):
         raise NotImplementedError # load_old_calculator(calc)
@@ -135,30 +155,3 @@ def get_pseudo_data(self, workflow):
                 pseudo_data[pseudo[0].attributes.element] = read_pseudo_file(temp_file)
 
         return pseudo_data
-
-
-def load_old_calculator(calc):
-    # This is a separate function so that it can be imported by other engines
-    loaded_calc = calc.__class__.fromfile(calc.directory / calc.prefix)
-
-    if loaded_calc.is_complete():
-        # If it is complete, load the results
-        calc.results = loaded_calc.results
-
-        # Check the convergence of the calculation
-        calc.check_convergence()
-
-        # Load k-points if relevant
-        if hasattr(loaded_calc, 'kpts'):
-            calc.kpts = loaded_calc.kpts
-
-        if isinstance(calc, ReturnsBandStructure):
-            calc.generate_band_structure()
-
-        if isinstance(calc, ProjwfcCalculator):
-            calc.generate_dos()
-
-        if isinstance(calc, PhCalculator):
-            calc.read_dynG()
-
-    return loaded_calc
\ No newline at end of file
diff --git a/src/aiida_koopmans/helpers.py b/src/aiida_koopmans/helpers.py
index e7c1573..be9207b 100644
--- a/src/aiida_koopmans/helpers.py
+++ b/src/aiida_koopmans/helpers.py
@@ -19,8 +19,8 @@
 from aiida.orm import Code, Computer
 from aiida_quantumespresso.calculations.pw import PwCalculation
 from aiida_wannier90.calculations.wannier90 import Wannier90Calculation
-from ase import io
-from ase.io.espresso import kch_keys, kcp_keys, kcs_keys, pw_keys, w2kcw_keys
+from ase_koopmans import io
+from ase_koopmans.io.espresso import kch_keys, kcp_keys, kcs_keys, pw_keys, w2kcw_keys
 
 from aiida_koopmans.calculations.kcw import KcwCalculation
 from aiida_koopmans.data.utils import generate_singlefiledata, generate_alpha_singlefiledata, produce_wannier90_files
diff --git a/src/aiida_koopmans/parsers/kcw.py b/src/aiida_koopmans/parsers/kcw.py
index 1305108..884f5f1 100644
--- a/src/aiida_koopmans/parsers/kcw.py
+++ b/src/aiida_koopmans/parsers/kcw.py
@@ -2,7 +2,7 @@
 from pathlib import Path
 import pathlib
 import tempfile
-from ase import io
+from ase_koopmans import io
 
 from aiida.orm import Dict
 
diff --git a/src/aiida_koopmans/utils.py b/src/aiida_koopmans/utils.py
index c259b2d..8f83ad1 100644
--- a/src/aiida_koopmans/utils.py
+++ b/src/aiida_koopmans/utils.py
@@ -10,8 +10,11 @@
 from aiida.orm import Code, Computer
 from aiida_quantumespresso.calculations.pw import PwCalculation
 from aiida_wannier90.calculations.wannier90 import Wannier90Calculation
-from ase import io
-from ase.io.espresso import kch_keys, kcp_keys, kcs_keys, pw_keys, w2kcw_keys
+
+from ase import Atoms
+from ase_koopmans import Atoms as AtomsKoopmans
+from ase_koopmans import io
+from ase_koopmans.io.espresso import kch_keys, kcp_keys, kcs_keys, pw_keys, w2kcw_keys
 
 from aiida_koopmans.calculations.kcw import KcwCalculation
 from aiida_koopmans.data.utils import generate_singlefiledata, generate_alpha_singlefiledata, produce_wannier90_files
@@ -54,7 +57,12 @@ def get_PwBaseWorkChain_from_ase(pw_calculator, step_data=None):
     """
     aiida_inputs = step_data['configuration']
     calc_params = pw_calculator._parameters
-    structure = orm.StructureData(ase=pw_calculator.atoms) # TODO: only one sdata, stored in the step_data dict. but some cases have output structure diff from input.
+
+    if isinstance(pw_calculator.atoms, AtomsKoopmans):
+        ase_atoms = Atoms.fromdict(pw_calculator.atoms.todict())
+
+    # WE NEED TO USE THE INPUT STRUCTURE OF SCF, WHEN WE DO NSCF
+    structure = orm.StructureData(ase=ase_atoms) # TODO: only one sdata, stored in the step_data dict. but some cases have output structure diff from input.
 
     pw_overrides = {
         "CONTROL": {},
@@ -93,7 +101,7 @@
         # here we need explicit kpoints
         builder.kpoints.set_kpoints(calc_params["kpts"].kpts,cartesian=False) # TODO: check cartesian false is correct.
 
-    parent_calculators = [f[0].directory for f in pw_calculator.linked_files.values() if f[0] is not None]
+    parent_calculators = [f[0].uid for f in pw_calculator.linked_files.values() if f[0] is not None]
     if len(set(parent_calculators)) > 1:
         raise ValueError("More than one parent calculator found.")
     elif len(set(parent_calculators)) == 1:
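
Note on the engine changes above (not part of the patch): the aiida.py hunks replace the blocking `_run_steps` loop with a polling-style interface keyed by `step.uid` — `run` submits a builder and marks the step RUNNING, `update_statuses` checks the AiiDA workchains, and `load_results` copies the outputs back onto the ASE calculator. A minimal sketch of a driver loop that could exercise this interface is shown below; the `drive_steps` helper and its arguments are hypothetical, while `Status` and the engine methods come from the diff.

# Hypothetical driver sketch (not part of the patch). It assumes an AiiDAEngine
# instance and a list of koopmans Step objects; the engine methods used here
# (run, get_status, update_statuses, load_results) are the ones added above.
from koopmans.status import Status


def drive_steps(engine, steps):
    # Submit every step that has not been started yet; run() marks it RUNNING
    # (or COMPLETED for the steps it skips).
    for step in steps:
        if engine.get_status(step) == Status.NOT_STARTED:
            engine.run(step)

    # Poll the AiiDA workchains until nothing is left RUNNING.
    # update_statuses() already sleeps between polls.
    while any(engine.get_status(step) == Status.RUNNING for step in steps):
        engine.update_statuses()

    # Pull the outputs of the completed steps back into their ASE calculators.
    for step in steps:
        if engine.get_status(step) == Status.COMPLETED:
            engine.load_results(step)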