Skip to content

Commit

Permalink
fixing all the decorators to avoid infinte recursion
Browse files Browse the repository at this point in the history
  • Loading branch information
mikibonacci committed May 22, 2024
1 parent a9f6853 commit f6a4bd6
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions src/aiida_koopmans/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@

from aiida_koopmans.calculations.kcw import KcwCalculation

"""
ASE calculator MUST have `wchain` attribute (the related AiiDA WorkChain) to be able to use these functions!
"""

LOCALHOST_NAME = "localhost-test"
KCW_BLOCKED_KEYWORDS = [t[1] for t in KcwCalculation._blocked_keywords]
PW_BLOCKED_KEYWORDS = [t[1] for t in PwCalculation._blocked_keywords]
Expand Down Expand Up @@ -571,7 +575,7 @@ def aiida_pre_calculate_trigger(_pre_calculate):
@functools.wraps(_pre_calculate)
def wrapper_aiida_trigger(self):
if self.parameters.mode == "ase":
return self._pre_calculate()
return _pre_calculate(self,)
else:
pass
return wrapper_aiida_trigger
Expand All @@ -581,7 +585,7 @@ def aiida_calculate_trigger(_calculate):
@functools.wraps(_calculate)
def wrapper_aiida_trigger(self):
if self.parameters.mode == "ase":
return self._calculate()
return _calculate(self,)
else:
builder = mapping_calculators[self.ext_out](self)
from aiida.engine import run_get_node, submit
Expand All @@ -595,7 +599,7 @@ def aiida_post_calculate_trigger(_post_calculate):
@functools.wraps(_post_calculate)
def wrapper_aiida_trigger(self):
if self.parameters.mode == "ase":
return self._post_calculate()
return _post_calculate(self,)
else:
pass
return wrapper_aiida_trigger
Expand All @@ -606,24 +610,26 @@ def aiida_read_results_trigger(read_results):
@functools.wraps(read_results)
def wrapper_aiida_trigger(self):
if self.parameters.mode == "ase":
return self.read_results()
return read_results(self,)
else:
if self.ext_out == ".wout":
output = read_output_file(self, self.wchain.outputs.wannier90.retrieved)
elif self.ext_out == ".pwo":
output = read_output_file(self)
self.calc = output.calc
self.results = output.calc.results
if hasattr(output.calc, 'kpts'):
self.kpts = output.calc.kpts

self.calc = output.calc
self.results = output.calc.results

return wrapper_aiida_trigger

def aiida_link_trigger(link):
# This wraps the link method of Workflow class.
@functools.wraps(link)
def wrapper_aiida_trigger(self,src_calc, src_path, dest_calc, dest_path):
if self.parameters.mode == "ase":
return self.link(src_calc, src_path, dest_calc, dest_path)
return link(self, src_calc, src_path, dest_calc, dest_path)
elif src_calc: # if pseudo linking, src_calc = None
dest_calc.parent_folder = src_calc.wchain.outputs.remote_folder
return wrapper_aiida_trigger

0 comments on commit f6a4bd6

Please sign in to comment.