From 2c6734320b121e2cfb8f610b6201241eef6d9ee4 Mon Sep 17 00:00:00 2001 From: Jusong Yu Date: Wed, 11 Oct 2023 16:35:45 +0200 Subject: [PATCH] Fix/209/dump convergence with healthy check (#228) * healthy check for convergence test * convergence test plot to output --- aiida_sssp_workflow/cli/inspect.py | 241 ++++++++++++++++++++++ aiida_sssp_workflow/protocol/criteria.yml | 10 + 2 files changed, 251 insertions(+) diff --git a/aiida_sssp_workflow/cli/inspect.py b/aiida_sssp_workflow/cli/inspect.py index a44ac1ae..65e13a89 100644 --- a/aiida_sssp_workflow/cli/inspect.py +++ b/aiida_sssp_workflow/cli/inspect.py @@ -7,8 +7,10 @@ import matplotlib.pyplot as plt import numpy as np from aiida import orm +from aiida.plugins import WorkflowFactory from aiida_sssp_workflow.cli import cmd_root +from aiida_sssp_workflow.utils import HIGH_DUAL_ELEMENTS, get_protocol def birch_murnaghan(V, E0, V0, B0, B01): @@ -80,6 +82,53 @@ def eos_plot( ax.set_title(title, fontsize=fontsize) +def convergence_plot( + ax, + convergence_data, + converged_xy, + y_thresholds_range, + y_label, + title, +): + xs = convergence_data["xs"] + ys = convergence_data["ys"] + + # xlim a bit larger than the range of xs + xlims = (min(xs) * 0.9, max(xs) * 1.02) + + ax.plot(xs, ys, "-x") + ax.scatter( + *converged_xy, + marker="s", + s=100, + label=f"Converge at {converged_xy[0]} Ry", + facecolors="none", + edgecolors="red", + ) + ax.fill_between( + x=xlims, + y1=y_thresholds_range[0], + y2=y_thresholds_range[1], + alpha=0.3, + color="green", + ) + ax.legend(loc="upper right", fontsize=8) + + # twice the range of ylimits if the y_thresholds_range just cover the y range + if y_thresholds_range[1] > max(ys) or y_thresholds_range[1] < max(ys) * 4: + y_max = y_thresholds_range[1] * 2 + y_min = -0.05 * y_max + ax.set_ylim(bottom=y_min, top=y_max) + + # change ticks size + ax.tick_params(axis="both", labelsize=6) + + ax.set_xlim(*xlims) + ax.set_ylabel(y_label, fontsize=8) + ax.set_xlabel("Cutoff (Ry)", fontsize=8) + ax.set_title(title, fontsize=8) + + @cmd_root.command("inspect") @click.argument("node") # The node to inspect, uuid or pk @click.option("--output", "-o", default="output", help="The output file name") @@ -142,6 +191,198 @@ def inspect(node, output): fig.tight_layout() fig.savefig(f"{output}_precision.pdf", bbox_inches="tight") + if "convergence" in wf_node.outputs: + convergence_summary = {} + + # A4 canvas for plot + # landscape mode, shoulder to shoulder for ecutwfc and ecutrho for each property + # five rows for five properties + rows = len(wf_node.outputs.convergence) + fig, axs = plt.subplots( + rows, + 2, + figsize=(11.69, 8.27), + dpi=100, + gridspec_kw={"width_ratios": [3, 1]}, + ) + subplot_index = 0 + + for property in [ + "bands", + "cohesive_energy", + "pressure", + "delta", + "phonon_frequencies", + ]: + # print summary of the convergence to a json file + convergence = wf_node.outputs.convergence[property] + + cutoff_control_protocol = wf_node.inputs.convergence.cutoff_control.value + cutoff_control = get_protocol("control", name=cutoff_control_protocol) + wfc_scan = cutoff_control["wfc_scan"] + # See if all the scans are finished by compare the list with the list of control protocol + ecutwfc_list = convergence.output_parameters_wfc_test.get_dict().get( + "ecutwfc" + ) + wfc_scan_healthy = len(ecutwfc_list) / len(wfc_scan) + + # if only first two scan values are not in the list, it is still regarted as 100% healthy + # Since it is because the first two ecutwfc values are too small for some elements + if wfc_scan_healthy != 1 and ( + wfc_scan[0] not in ecutwfc_list or wfc_scan[1] not in ecutwfc_list + ): + wfc_scan_healthy = 1 + + ecutrho_test_list = convergence.output_parameters_rho_test.get_dict().get( + "ecutrho" + ) + + element = wf_node.outputs.pseudo_info.get_dict().get("element") + pp_type = wf_node.outputs.pseudo_info.get_dict().get("pp_type") + if pp_type in ["nc", "sl"]: + expected_len_dual_scan = cutoff_control["nc_dual_scan"] + else: + if element in HIGH_DUAL_ELEMENTS: + expected_len_dual_scan = cutoff_control["nonnc_high_dual_scan"] + else: + expected_len_dual_scan = cutoff_control["nonnc_dual_scan"] + + # minus one for the reference value + rho_scan_healthy = (len(ecutrho_test_list) - 1) / len( + expected_len_dual_scan + ) + + color = "red" if wfc_scan_healthy != 1 or rho_scan_healthy != 1 else "green" + click.secho( + f"Convergence scan healthy check for {property}: wavefunction scan = {round(wfc_scan_healthy*100, 2)}%, charge density scan = {round(rho_scan_healthy*100, 2)}%", + fg=color, + ) + + # print summary of the convergence to a json file + # be careful the key for charge density is "chargedensity_cutoff" instead of "charge_density_cutoff + property_summary = convergence.output_parameters.get_dict() + property_summary["wfc_scan_healthy"] = wfc_scan_healthy + property_summary["rho_scan_healthy"] = rho_scan_healthy + + convergence_summary[property] = property_summary + + # plot to the ax + # ax1 on the left for ecutwfc + # ax2 on the right for ecutrho + # the ratio of the width is 3:1 + ax1 = axs.flat[subplot_index] + ax2 = axs.flat[subplot_index + 1] + + # data preparation + # Will only plot the measured properties e.g. for bands it is the eta_c + _ConvergenceWorkChain = WorkflowFactory( + f"sssp_workflow.convergence.{property}" + ) + measured_key = _ConvergenceWorkChain._MEASURE_OUT_PROPERTY + + used_criteria = convergence.output_parameters.get_dict().get( + "used_criteria" + ) + crieria_protocol = get_protocol("criteria", name=used_criteria) + y_thresholds_range = crieria_protocol[property]["bounds"] + y_unit = crieria_protocol[property]["unit"] + # use greek letter delta + y_label = f"Δ {y_unit}" + + conv_data = {} + conv_data["xs"] = convergence.output_parameters_wfc_test.get_dict().get( + "ecutwfc" + ) + conv_data["ys"] = convergence.output_parameters_wfc_test.get_dict().get( + measured_key + ) + + _x = convergence.output_parameters.get_dict().get("wavefunction_cutoff") + _y = dict(zip(conv_data["xs"], conv_data["ys"])).get(_x) + converged_xy = (_x, _y) + + _max_ecutwfc = conv_data["xs"][-1] + _max_ecutrho = convergence.output_parameters_rho_test.get_dict().get( + "ecutrho" + )[-1] + dual = round(_max_ecutrho / _max_ecutwfc, 1) + + property_name = property.replace("_", " ").capitalize() + + title = f"{property_name} convergence wrt wavefunction cutoff (at charge density cutoff = wavefunction cutoff * {dual} Ry)" + + convergence_plot( + ax1, + conv_data, + converged_xy, + y_thresholds_range, + y_label=y_label, + title=title, + ) + + # data preparation for ecutrho + conv_data = {} + conv_data["xs"] = convergence.output_parameters_rho_test.get_dict().get( + "ecutrho" + ) + conv_data["ys"] = convergence.output_parameters_rho_test.get_dict().get( + measured_key + ) + + ecutwfc = convergence.output_parameters.get_dict().get( + "wavefunction_cutoff" + ) + + _x = convergence.output_parameters.get_dict().get("chargedensity_cutoff") + _y = dict(zip(conv_data["xs"], conv_data["ys"])).get(_x) + converged_xy = (_x, _y) + + title = f"charge density cutoff (at wavefunction cutoff {ecutwfc} Ry)" + + convergence_plot( + ax2, + conv_data, + converged_xy, + y_thresholds_range, + y_label=y_label, + title=title, + ) + + # jump to the next row + subplot_index += 2 + + # calculate the recommended cutoffs from the maximum of all properties scan + recommended_ecutwfc = 0 + recommended_ecutrho = 0 + for value in convergence_summary.values(): + recommended_ecutwfc = max(recommended_ecutwfc, value["wavefunction_cutoff"]) + recommended_ecutrho = max( + recommended_ecutrho, value["chargedensity_cutoff"] + ) + + convergence_summary["recommended_cutoffs"] = { + "wavefunction_cutoff": recommended_ecutwfc, + "chargedensity_cutoff": recommended_ecutrho, + } + + try: + d_str = json.dumps(convergence_summary, indent=4) + with open(f"{output}_convergence_summary.json", "w") as f: + f.write(d_str) + except: + pass + + # fig to pdf + psp_label = wf_node.base.extras.all["label"].split(" ")[-1] + criteria = wf_node.inputs.convergence.criteria.value + fig.tight_layout() + fig.suptitle( + f"Convergence verification for {psp_label} under {criteria} creteria", + fontsize=10, + ) + fig.subplots_adjust(top=0.92) + fig.savefig(f"{output}_convergence.pdf", bbox_inches="tight") + if __name__ == "__main__": inspect() diff --git a/aiida_sssp_workflow/protocol/criteria.yml b/aiida_sssp_workflow/protocol/criteria.yml index 7597e6dc..b24ab73a 100644 --- a/aiida_sssp_workflow/protocol/criteria.yml +++ b/aiida_sssp_workflow/protocol/criteria.yml @@ -7,26 +7,31 @@ efficiency: mode: 0 bounds: [0.0, 2.0] # when relative error < 2.0 meV/atom eps: 1.0e-3 + unit: meV/atom delta: mode: 0 bounds: [0.0, 0.2] # when absolute error < 0.2 meV/atom eps: 1.0e-3 + unit: meV/atom phonon_frequencies: mode: 0 bounds: [0.0, 2.0] # when relative error < 2.0% eps: 1.0e-3 + unit: "%" pressure: mode: 0 bounds: [0.0, 1.0] # when relative error < 1.0% eps: 1.0e-3 + unit: "%" bands: mode: 0 bounds: [0.0, 20] # when error eta_c < 20 meV eps: 1.0e-3 + unit: meV/atom precision: @@ -37,23 +42,28 @@ precision: mode: 0 bounds: [0.0, 2.0] # when relative error < 2.0 meV/atom eps: 1.0e-3 + unit: meV/atom delta: mode: 0 bounds: [0.0, 0.1] # when absolute error < 0.1 meV/atom eps: 1.0e-3 + unit: meV/atom phonon_frequencies: mode: 0 bounds: [0.0, 1.0] # when error < 1.0 % eps: 1.0e-3 + unit: "%" pressure: mode: 0 bounds: [0.0, 0.5] # when relative error < 0.5% eps: 1.0e-3 + unit: "%" bands: mode: 0 bounds: [0.0, 15] # when error eta_c < 15 meV eps: 1.0e-3 + unit: meV/atom