Skip to content

Commit

Permalink
Fix/209/dump convergence with healthy check (#228)
Browse files Browse the repository at this point in the history
* healthy check for convergence test

* convergence test plot to output
  • Loading branch information
unkcpz authored Oct 11, 2023
1 parent b8ab488 commit 2c67343
Show file tree
Hide file tree
Showing 2 changed files with 251 additions and 0 deletions.
241 changes: 241 additions & 0 deletions aiida_sssp_workflow/cli/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
import matplotlib.pyplot as plt
import numpy as np
from aiida import orm
from aiida.plugins import WorkflowFactory

from aiida_sssp_workflow.cli import cmd_root
from aiida_sssp_workflow.utils import HIGH_DUAL_ELEMENTS, get_protocol


def birch_murnaghan(V, E0, V0, B0, B01):
Expand Down Expand Up @@ -80,6 +82,53 @@ def eos_plot(
ax.set_title(title, fontsize=fontsize)


def convergence_plot(
ax,
convergence_data,
converged_xy,
y_thresholds_range,
y_label,
title,
):
xs = convergence_data["xs"]
ys = convergence_data["ys"]

# xlim a bit larger than the range of xs
xlims = (min(xs) * 0.9, max(xs) * 1.02)

ax.plot(xs, ys, "-x")
ax.scatter(
*converged_xy,
marker="s",
s=100,
label=f"Converge at {converged_xy[0]} Ry",
facecolors="none",
edgecolors="red",
)
ax.fill_between(
x=xlims,
y1=y_thresholds_range[0],
y2=y_thresholds_range[1],
alpha=0.3,
color="green",
)
ax.legend(loc="upper right", fontsize=8)

# twice the range of ylimits if the y_thresholds_range just cover the y range
if y_thresholds_range[1] > max(ys) or y_thresholds_range[1] < max(ys) * 4:
y_max = y_thresholds_range[1] * 2
y_min = -0.05 * y_max
ax.set_ylim(bottom=y_min, top=y_max)

# change ticks size
ax.tick_params(axis="both", labelsize=6)

ax.set_xlim(*xlims)
ax.set_ylabel(y_label, fontsize=8)
ax.set_xlabel("Cutoff (Ry)", fontsize=8)
ax.set_title(title, fontsize=8)


@cmd_root.command("inspect")
@click.argument("node") # The node to inspect, uuid or pk
@click.option("--output", "-o", default="output", help="The output file name")
Expand Down Expand Up @@ -142,6 +191,198 @@ def inspect(node, output):
fig.tight_layout()
fig.savefig(f"{output}_precision.pdf", bbox_inches="tight")

if "convergence" in wf_node.outputs:
convergence_summary = {}

# A4 canvas for plot
# landscape mode, shoulder to shoulder for ecutwfc and ecutrho for each property
# five rows for five properties
rows = len(wf_node.outputs.convergence)
fig, axs = plt.subplots(
rows,
2,
figsize=(11.69, 8.27),
dpi=100,
gridspec_kw={"width_ratios": [3, 1]},
)
subplot_index = 0

for property in [
"bands",
"cohesive_energy",
"pressure",
"delta",
"phonon_frequencies",
]:
# print summary of the convergence to a json file
convergence = wf_node.outputs.convergence[property]

cutoff_control_protocol = wf_node.inputs.convergence.cutoff_control.value
cutoff_control = get_protocol("control", name=cutoff_control_protocol)
wfc_scan = cutoff_control["wfc_scan"]
# See if all the scans are finished by compare the list with the list of control protocol
ecutwfc_list = convergence.output_parameters_wfc_test.get_dict().get(
"ecutwfc"
)
wfc_scan_healthy = len(ecutwfc_list) / len(wfc_scan)

# if only first two scan values are not in the list, it is still regarted as 100% healthy
# Since it is because the first two ecutwfc values are too small for some elements
if wfc_scan_healthy != 1 and (
wfc_scan[0] not in ecutwfc_list or wfc_scan[1] not in ecutwfc_list
):
wfc_scan_healthy = 1

ecutrho_test_list = convergence.output_parameters_rho_test.get_dict().get(
"ecutrho"
)

element = wf_node.outputs.pseudo_info.get_dict().get("element")
pp_type = wf_node.outputs.pseudo_info.get_dict().get("pp_type")
if pp_type in ["nc", "sl"]:
expected_len_dual_scan = cutoff_control["nc_dual_scan"]
else:
if element in HIGH_DUAL_ELEMENTS:
expected_len_dual_scan = cutoff_control["nonnc_high_dual_scan"]
else:
expected_len_dual_scan = cutoff_control["nonnc_dual_scan"]

# minus one for the reference value
rho_scan_healthy = (len(ecutrho_test_list) - 1) / len(
expected_len_dual_scan
)

color = "red" if wfc_scan_healthy != 1 or rho_scan_healthy != 1 else "green"
click.secho(
f"Convergence scan healthy check for {property}: wavefunction scan = {round(wfc_scan_healthy*100, 2)}%, charge density scan = {round(rho_scan_healthy*100, 2)}%",
fg=color,
)

# print summary of the convergence to a json file
# be careful the key for charge density is "chargedensity_cutoff" instead of "charge_density_cutoff
property_summary = convergence.output_parameters.get_dict()
property_summary["wfc_scan_healthy"] = wfc_scan_healthy
property_summary["rho_scan_healthy"] = rho_scan_healthy

convergence_summary[property] = property_summary

# plot to the ax
# ax1 on the left for ecutwfc
# ax2 on the right for ecutrho
# the ratio of the width is 3:1
ax1 = axs.flat[subplot_index]
ax2 = axs.flat[subplot_index + 1]

# data preparation
# Will only plot the measured properties e.g. for bands it is the eta_c
_ConvergenceWorkChain = WorkflowFactory(
f"sssp_workflow.convergence.{property}"
)
measured_key = _ConvergenceWorkChain._MEASURE_OUT_PROPERTY

used_criteria = convergence.output_parameters.get_dict().get(
"used_criteria"
)
crieria_protocol = get_protocol("criteria", name=used_criteria)
y_thresholds_range = crieria_protocol[property]["bounds"]
y_unit = crieria_protocol[property]["unit"]
# use greek letter delta
y_label = f"Δ {y_unit}"

conv_data = {}
conv_data["xs"] = convergence.output_parameters_wfc_test.get_dict().get(
"ecutwfc"
)
conv_data["ys"] = convergence.output_parameters_wfc_test.get_dict().get(
measured_key
)

_x = convergence.output_parameters.get_dict().get("wavefunction_cutoff")
_y = dict(zip(conv_data["xs"], conv_data["ys"])).get(_x)
converged_xy = (_x, _y)

_max_ecutwfc = conv_data["xs"][-1]
_max_ecutrho = convergence.output_parameters_rho_test.get_dict().get(
"ecutrho"
)[-1]
dual = round(_max_ecutrho / _max_ecutwfc, 1)

property_name = property.replace("_", " ").capitalize()

title = f"{property_name} convergence wrt wavefunction cutoff (at charge density cutoff = wavefunction cutoff * {dual} Ry)"

convergence_plot(
ax1,
conv_data,
converged_xy,
y_thresholds_range,
y_label=y_label,
title=title,
)

# data preparation for ecutrho
conv_data = {}
conv_data["xs"] = convergence.output_parameters_rho_test.get_dict().get(
"ecutrho"
)
conv_data["ys"] = convergence.output_parameters_rho_test.get_dict().get(
measured_key
)

ecutwfc = convergence.output_parameters.get_dict().get(
"wavefunction_cutoff"
)

_x = convergence.output_parameters.get_dict().get("chargedensity_cutoff")
_y = dict(zip(conv_data["xs"], conv_data["ys"])).get(_x)
converged_xy = (_x, _y)

title = f"charge density cutoff (at wavefunction cutoff {ecutwfc} Ry)"

convergence_plot(
ax2,
conv_data,
converged_xy,
y_thresholds_range,
y_label=y_label,
title=title,
)

# jump to the next row
subplot_index += 2

# calculate the recommended cutoffs from the maximum of all properties scan
recommended_ecutwfc = 0
recommended_ecutrho = 0
for value in convergence_summary.values():
recommended_ecutwfc = max(recommended_ecutwfc, value["wavefunction_cutoff"])
recommended_ecutrho = max(
recommended_ecutrho, value["chargedensity_cutoff"]
)

convergence_summary["recommended_cutoffs"] = {
"wavefunction_cutoff": recommended_ecutwfc,
"chargedensity_cutoff": recommended_ecutrho,
}

try:
d_str = json.dumps(convergence_summary, indent=4)
with open(f"{output}_convergence_summary.json", "w") as f:
f.write(d_str)
except:
pass

# fig to pdf
psp_label = wf_node.base.extras.all["label"].split(" ")[-1]
criteria = wf_node.inputs.convergence.criteria.value
fig.tight_layout()
fig.suptitle(
f"Convergence verification for {psp_label} under {criteria} creteria",
fontsize=10,
)
fig.subplots_adjust(top=0.92)
fig.savefig(f"{output}_convergence.pdf", bbox_inches="tight")


if __name__ == "__main__":
inspect()
10 changes: 10 additions & 0 deletions aiida_sssp_workflow/protocol/criteria.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,31 @@ efficiency:
mode: 0
bounds: [0.0, 2.0] # when relative error < 2.0 meV/atom
eps: 1.0e-3
unit: meV/atom

delta:
mode: 0
bounds: [0.0, 0.2] # when absolute error < 0.2 meV/atom
eps: 1.0e-3
unit: meV/atom

phonon_frequencies:
mode: 0
bounds: [0.0, 2.0] # when relative error < 2.0%
eps: 1.0e-3
unit: "%"

pressure:
mode: 0
bounds: [0.0, 1.0] # when relative error < 1.0%
eps: 1.0e-3
unit: "%"

bands:
mode: 0
bounds: [0.0, 20] # when error eta_c < 20 meV
eps: 1.0e-3
unit: meV/atom


precision:
Expand All @@ -37,23 +42,28 @@ precision:
mode: 0
bounds: [0.0, 2.0] # when relative error < 2.0 meV/atom
eps: 1.0e-3
unit: meV/atom

delta:
mode: 0
bounds: [0.0, 0.1] # when absolute error < 0.1 meV/atom
eps: 1.0e-3
unit: meV/atom

phonon_frequencies:
mode: 0
bounds: [0.0, 1.0] # when error < 1.0 %
eps: 1.0e-3
unit: "%"

pressure:
mode: 0
bounds: [0.0, 0.5] # when relative error < 0.5%
eps: 1.0e-3
unit: "%"

bands:
mode: 0
bounds: [0.0, 15] # when error eta_c < 15 meV
eps: 1.0e-3
unit: meV/atom

0 comments on commit 2c67343

Please sign in to comment.