Skip to content

Commit

Permalink
Merge branch 'fixing/compatibility' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
Miki Bonacci committed Feb 19, 2024
2 parents 37598ca + 7abce1a commit e9b27b6
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 12 deletions.
34 changes: 23 additions & 11 deletions aiida_yambo_wannier90/workflows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from aiida_wannier90_workflows.utils.kpoints import (
get_explicit_kpoints,
get_mesh_from_kpoints,
get_path_from_kpoints
)
from aiida_wannier90_workflows.utils.workflows.builder.setter import set_kpoints
from aiida_wannier90_workflows.workflows import (
Expand Down Expand Up @@ -626,9 +627,7 @@ def setup(self) -> None: # pylint: disable=inconsistent-return-statements
"""Initialize context variables."""

self.ctx.current_structure = self.inputs.structure

if "bands_kpoints" in self.inputs:
self.ctx.bands_kpoints = self.inputs.bands_kpoints


# Converged mesh from YamboConvergence
self.ctx.kpoints_gw_conv = None
Expand Down Expand Up @@ -676,7 +675,13 @@ def setup(self) -> None: # pylint: disable=inconsistent-return-statements

def should_run_seekpath(self):
"""Run seekpath if the `inputs.bands_kpoints` is not provided."""
return "bands_kpoints" not in self.inputs
if "bands_kpoints" in self.inputs:
self.ctx.current_kpoint_path = get_path_from_kpoints(
self.inputs["bands_kpoints"]
)
return False
else:
return True

def run_seekpath(self):
"""Run the structure through SeeKpath to get the primitive and normalized structure."""
Expand All @@ -692,7 +697,11 @@ def run_seekpath(self):

self.ctx.current_structure = result["primitive_structure"]

self.ctx.current_bands_kpoints = result["explicit_kpoints"]
# Add `kpoint_path` for Wannier bands
self.ctx.current_kpoint_path = get_path_from_kpoints(
result["explicit_kpoints"]
)


structure_formula = self.inputs.structure.get_formula()
primitive_structure_formula = result["primitive_structure"].get_formula()
Expand Down Expand Up @@ -1056,11 +1065,12 @@ def prepare_wannier90_pp_inputs(self) -> AttributeDict:

inputs.wannier90.structure = self.ctx.current_structure

#params = inputs.wannier90.parameters.get_dict()
#params["bands_plot"] = False
#inputs.wannier90.parameters = orm.Dict(params)
params = inputs.wannier90.parameters.get_dict()
params["bands_plot"] = False
inputs.wannier90.parameters = orm.Dict(params)

inputs.wannier90.bands_kpoints = self.ctx.current_bands_kpoints
if self.ctx.current_kpoint_path:
inputs.wannier90.kpoint_path = self.ctx.current_kpoint_path

# Use commensurate kmesh
if self.ctx.kpoints_w90_input != self.ctx.kpoints_w90:
Expand Down Expand Up @@ -1172,7 +1182,8 @@ def prepare_wannier90_inputs(self) -> AttributeDict:
)

inputs.structure = self.ctx.current_structure
inputs.bands_kpoints = self.ctx.current_bands_kpoints
if self.ctx.current_kpoint_path:
inputs.wannier90.wannier90.kpoint_path = self.ctx.current_kpoint_path

# Use commensurate kmesh
if self.ctx.kpoints_w90_input != self.ctx.kpoints_w90:
Expand Down Expand Up @@ -1258,7 +1269,8 @@ def prepare_wannier90_qp_inputs(self) -> AttributeDict:
)

inputs.wannier90.structure = self.ctx.current_structure
inputs.wannier90.bands_kpoints = self.ctx.current_bands_kpoints
if self.ctx.current_kpoint_path:
inputs.kpoint_path = self.ctx.current_kpoint_path

if self.ctx.kpoints_w90_input != self.ctx.kpoints_w90:
set_kpoints(
Expand Down
2 changes: 1 addition & 1 deletion examples/example_01.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from aiida_wannier90_workflows.cli.params import RUN
from aiida_wannier90_workflows.utils.workflows.builder.serializer import print_builder
from aiida_wannier90_workflows.utils.kpoints import get_explicit_kpoints_from_mesh
from aiida_wannier90_workflows.utils.workflows.builder.setter import set_parallelization, set_num_bands, set_kpoints
from aiida_wannier90_workflows.utils.workflows.builder.setter import set_parallelization, set_num_bands
from aiida_wannier90_workflows.utils.workflows.builder.submit import submit_and_add_group
from aiida_wannier90_workflows.common.types import WannierProjectionType
from aiida_wannier90_workflows.workflows import Wannier90BandsWorkChain
Expand Down

0 comments on commit e9b27b6

Please sign in to comment.