diff --git a/pyproject.toml b/pyproject.toml
index 72c78cf4..634ca24d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -73,6 +73,7 @@ filterwarnings = [
     'ignore:Creating AiiDA configuration folder.*:UserWarning',
     'ignore:Object of type .* not in session, .* operation along .* will not proceed:sqlalchemy.exc.SAWarning',
     'ignore:The `Code` class is deprecated.*:aiida.common.warnings.AiidaDeprecationWarning',
+    'ignore:`CalcJobNode.*` is deprecated.*:aiida.common.warnings.AiidaDeprecationWarning',
 ]
 markers = [
     "slow: marks tests as slow (deselect with '-m \"not slow\"')",
diff --git a/src/aiida_sssp_workflow/workflows/convergence/_base.py b/src/aiida_sssp_workflow/workflows/convergence/_base.py
index a9322ff3..0863c739 100644
--- a/src/aiida_sssp_workflow/workflows/convergence/_base.py
+++ b/src/aiida_sssp_workflow/workflows/convergence/_base.py
@@ -53,8 +53,8 @@ def is_valid_convergence_configuration(value, _=None):
 
 def is_valid_cutoff_list(cutoff_list, _=None):
     """Check the cutoff list is a list of tuples and the cutoffs are increasing"""
-    if not all(isinstance(cutoff, tuple) for cutoff in cutoff_list):
-        return "cutoff_list must be a list of tuples"
+    if not all(isinstance(cutoff, (tuple, list)) for cutoff in cutoff_list):
+        return "cutoff_list must be a list of tuples or lists."
     if not all(
         cutoff_list[i][0] < cutoff_list[i + 1][0] for i in range(len(cutoff_list) - 1)
     ):
diff --git a/src/aiida_sssp_workflow/workflows/evaluate/_bands.py b/src/aiida_sssp_workflow/workflows/evaluate/_bands.py
index 028e2b2e..f6210bcc 100644
--- a/src/aiida_sssp_workflow/workflows/evaluate/_bands.py
+++ b/src/aiida_sssp_workflow/workflows/evaluate/_bands.py
@@ -154,9 +154,6 @@ def inspect_bands(self):
         """
         workchain = self.ctx.workchain_bands
 
-        if workchain.is_finished:
-            self._disable_cache(workchain)
-
         if not workchain.is_finished_ok:
             self.logger.warning(
                 f"PwBandsWorkChain for bands evaluation failed with exit status {workchain.exit_status}"
diff --git a/src/aiida_sssp_workflow/workflows/evaluate/_caching_wise_bands.py b/src/aiida_sssp_workflow/workflows/evaluate/_caching_wise_bands.py
index 7839f884..f927e25a 100644
--- a/src/aiida_sssp_workflow/workflows/evaluate/_caching_wise_bands.py
+++ b/src/aiida_sssp_workflow/workflows/evaluate/_caching_wise_bands.py
@@ -359,7 +359,7 @@ def inspect_bands(self):
         is_cleaned = self.ctx.current_folder.base.extras.get("cleaned", False)
         if is_cleaned:
             self.logger.warning(
-                f"PhBaseWorkChain failed because the remote folder is empty with exit status {workchain.exit_status}, invalid the caching of the node and re-run scf calculation."
+                f"PwBaseWorkChain failed because the remote folder is empty with exit status {workchain.exit_status}; invalidating the caching of the node and re-running the scf calculation."
             )
             # invalid the caching of the node and re-run scf calculation
             workchain_scf = self.ctx.workchain_scf
diff --git a/tests/workflows/convergence/test_caching.py b/tests/workflows/convergence/test_caching.py
new file mode 100644
index 00000000..b8e7820b
--- /dev/null
+++ b/tests/workflows/convergence/test_caching.py
@@ -0,0 +1,267 @@
+import pytest
+
+from aiida.plugins import DataFactory, WorkflowFactory
+from aiida.engine import ProcessBuilder, run_get_node
+from aiida.manage.caching import enable_caching
+
+
+UpfData = DataFactory("pseudo.upf")
+
+
+@pytest.mark.slow
+@pytest.mark.usefixtures("aiida_profile_clean")
+def test_caching_bands(
+    pseudo_path,
+    code_generator,
+):
+    """Test that caching works for the first pw calculation of bands.
+
+    Run the bands convergence workflow twice and check that the second run
+    reuses the scf and bands calcjobs from the cache."""
+    _ConvergenceBandsWorkChain = WorkflowFactory("sssp_workflow.convergence.bands")
+
+    # Caching should be turned on also for the first preparing run.
+    # Otherwise, the scf calculation inside the bands workflow has two
+    # duplicate nodes that share the same hash but are both valid cache
+    # sources, which causes a caching race condition.
+    with enable_caching(identifier="aiida.calculations:quantumespresso.*"):
+        bands_builder: ProcessBuilder = _ConvergenceBandsWorkChain.get_builder(
+            pseudo=pseudo_path("Al"),
+            protocol="test",
+            cutoff_list=[(20, 80), (30, 120)],
+            configuration="DC",
+            code=code_generator("pw"),
+            clean_workdir=True,
+        )
+
+        # Run a bands convergence workflow first and check that the SCF is not
+        # from a cached calcjob
+        _, source_node = run_get_node(bands_builder)
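+
+        # NOTE: the positional `called` chains below walk the provenance tree;
+        # the mapping assumed in this annotation (not guaranteed by the
+        # workchain API) is roughly:
+        #   source_ref_wf.called[1]   -> the bands evaluation workchain
+        #   ...called[0] / .called[1] -> its scf / bands sub-workchains
+        #   the final .called[...]    -> the underlying pw.x calcjob
+        # A calcjob restored from cache carries the `_aiida_cached_from` extra
+        # (set by AiiDA to the uuid of its cache source); a fresh run does not.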
+
+        # check the first scf of the reference run
+        # The pw calculation
+        source_ref_wf = [
+            p
+            for p in source_node.called
+            if p.base.extras.get("wavefunction_cutoff", None) == 30
+        ][0]
+        source_scf_calcjob_node = source_ref_wf.called[1].called[0].called[1]
+        assert (
+            source_scf_calcjob_node.base.extras.get("_aiida_cached_from", None) is None
+        )
+
+        # check the first bands calculation of the reference run is not from cache
+        # The pw calculation
+        source_band_calcjob_node = source_ref_wf.called[1].called[1].called[0]
+        assert (
+            source_band_calcjob_node.base.extras.get("_aiida_cached_from", None) is None
+        )
+
+        # Run again and check it is using caching
+        _, cached_node = run_get_node(bands_builder)
+        cached_ref_wf = [
+            p
+            for p in cached_node.called
+            if p.base.extras.get("wavefunction_cutoff", None) == 30
+        ][0]
+        cached_scf_calcjob_node = cached_ref_wf.called[1].called[0].called[1]
+
+        assert (
+            cached_scf_calcjob_node.base.extras.get("_aiida_cached_from", None)
+            == source_scf_calcjob_node.uuid
+        )
+        assert not cached_scf_calcjob_node.base.caching.is_valid_cache
+
+        cached_band_calcjob_node = cached_ref_wf.called[1].called[1].called[0]
+
+        assert (
+            cached_band_calcjob_node.base.extras.get("_aiida_cached_from", None)
+            == source_band_calcjob_node.uuid
+        )
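+
+
+# NOTE: the positional index chains used in these tests are brittle against
+# changes in the workchain call order. A label-based lookup would be more
+# robust; a minimal sketch (hypothetical helper, not used by these tests):
+#
+#     def find_called(node, process_label):
+#         """Return the sub-processes of `node` with the given process label."""
+#         return [child for child in node.called if child.process_label == process_label]
+#
+# `process_label` is a standard attribute of AiiDA process nodes, so e.g.
+# find_called(source_ref_wf, "PwBaseWorkChain") would select sub-workchains
+# regardless of their position in `called`.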
+
+
+@pytest.mark.slow
+@pytest.mark.usefixtures("aiida_profile_clean")
+def test_caching_phonon_frequencies(
+    pseudo_path,
+    code_generator,
+):
+    """Test that caching works for the first pw calculation of phonon_frequencies.
+
+    Run the phonon_frequencies convergence workflow twice and check that the
+    second run reuses the scf and ph calcjobs from the cache."""
+    _ConvergencePhononFrequenciesWorkChain = WorkflowFactory(
+        "sssp_workflow.convergence.phonon_frequencies"
+    )
+
+    # Caching should be turned on also for the first preparing run.
+    # Otherwise, the scf calculation inside the phonon_frequencies workflow has
+    # two duplicate nodes that share the same hash but are both valid cache
+    # sources, which causes a caching race condition.
+    with enable_caching(identifier="aiida.calculations:quantumespresso.*"):
+        phonon_frequencies_builder: ProcessBuilder = (
+            _ConvergencePhononFrequenciesWorkChain.get_builder(
+                pseudo=pseudo_path("Al"),
+                protocol="test",
+                cutoff_list=[(20, 80), (30, 120)],
+                configuration="DC",
+                pw_code=code_generator("pw"),
+                ph_code=code_generator("ph"),
+                clean_workdir=True,
+            )
+        )
+
+        # Run a phonon frequencies convergence workflow first and check that
+        # the SCF is not from a cached calcjob
+        _, source_node = run_get_node(phonon_frequencies_builder)
+
+        # check the first scf of the reference run
+        # The pw calculation
+        source_ref_wf = [
+            p
+            for p in source_node.called
+            if p.base.extras.get("wavefunction_cutoff", None) == 30
+        ][0]
+        source_scf_calcjob_node = source_ref_wf.called[0].called[1]
+        assert (
+            source_scf_calcjob_node.base.extras.get("_aiida_cached_from", None) is None
+        )
+
+        # The ph calculation
+        source_ph_calcjob_node = source_ref_wf.called[1].called[0]
+        assert (
+            source_ph_calcjob_node.base.extras.get("_aiida_cached_from", None) is None
+        )
+
+        # Run again and check it is using caching
+        _, cached_node = run_get_node(phonon_frequencies_builder)
+        cached_ref_wf = [
+            p
+            for p in cached_node.called
+            if p.base.extras.get("wavefunction_cutoff", None) == 30
+        ][0]
+        cached_scf_calcjob_node = cached_ref_wf.called[0].called[1]
+
+        assert (
+            cached_scf_calcjob_node.base.extras.get("_aiida_cached_from", None)
+            == source_scf_calcjob_node.uuid
+        )
+        assert not cached_scf_calcjob_node.base.caching.is_valid_cache
+
+        # Check the ph calculation is also from the cached source
+        cached_ph_calcjob_node = cached_ref_wf.called[1].called[0]
+        assert (
+            cached_ph_calcjob_node.base.extras.get("_aiida_cached_from", None)
+            == source_ph_calcjob_node.uuid
+        )
+
+
+@pytest.mark.slow
+@pytest.mark.usefixtures("aiida_profile_clean")
+def test_caching_bands_rerun_pw_prepare(
+    pseudo_path,
+    code_generator,
+):
+    """Test that caching works for the first pw calculation of bands.
+
+    After the first run, manually mark the bands calculation as an invalid
+    cache so it reruns with an empty remote folder. The test checks that the
+    scf reruns if the remote does not exist."""
+    _ConvergenceBandsWorkChain = WorkflowFactory("sssp_workflow.convergence.bands")
+
+    # Caching should be turned on also for the first preparing run.
+    # Otherwise, the scf calculation inside the bands workflow has two
+    # duplicate nodes that share the same hash but are both valid cache
+    # sources, which causes a caching race condition.
+    with enable_caching(identifier="aiida.calculations:quantumespresso.*"):
+        bands_builder: ProcessBuilder = _ConvergenceBandsWorkChain.get_builder(
+            pseudo=pseudo_path("Al"),
+            protocol="test",
+            cutoff_list=[(20, 80), (30, 120)],
+            configuration="DC",
+            code=code_generator("pw"),
+            clean_workdir=True,
+        )
+
+        # Run a bands convergence workflow first and check that the SCF is not
+        # from a cached calcjob
+        _, source_node = run_get_node(bands_builder)
+
+        # Make the source band calculation an invalid cache
+        source_ref_wf = [
+            p
+            for p in source_node.called
+            if p.base.extras.get("wavefunction_cutoff", None) == 30
+        ][0]
+        source_band_calcjob_node = source_ref_wf.called[1].called[1].called[0]
+        assert source_band_calcjob_node.is_valid_cache
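+
+        # `node.is_valid_cache` is the pre-AiiDA-2.x spelling of
+        # `node.base.caching.is_valid_cache` used elsewhere in this file; the
+        # deprecated attribute is presumably what triggers the `CalcJobNode.*`
+        # deprecation warning ignored in pyproject.toml above. Setting it to
+        # False excludes the node from ever being used as a cache source.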
+        source_band_calcjob_node.is_valid_cache = False
+
+        # Run again and check it is using caching
+        _, cached_node = run_get_node(bands_builder)
+
+        # Check the bands workchain finished okay
+        cached_ref_wf = [
+            p
+            for p in cached_node.called
+            if p.base.extras.get("wavefunction_cutoff", None) == 30
+        ][0]
+        assert cached_ref_wf.called[1].is_finished_ok
+
+        cached_band_calcjob_node = cached_ref_wf.called[1].called[1].called[0]
+        assert (
+            cached_band_calcjob_node.base.extras.get("_aiida_cached_from", None) is None
+        )
+        assert cached_band_calcjob_node.exit_code.status == 305
+
+
+@pytest.mark.slow
+@pytest.mark.usefixtures("aiida_profile_clean")
+def test_caching_phonon_frequencies_rerun_pw_prepare(
+    pseudo_path,
+    code_generator,
+):
+    """Test that caching works for the first pw calculation of phonon_frequencies.
+
+    After the first run, manually mark the ph calculation as an invalid cache
+    so it reruns with an empty remote folder. The test checks that the
+    preparing pw.x calculation reruns if the remote does not exist."""
+    _ConvergencePhononFrequenciesWorkChain = WorkflowFactory(
+        "sssp_workflow.convergence.phonon_frequencies"
+    )
+
+    # Caching should be turned on also for the first preparing run.
+    # Otherwise, the scf calculation inside the phonon_frequencies workflow has
+    # two duplicate nodes that share the same hash but are both valid cache
+    # sources, which causes a caching race condition.
+    with enable_caching(identifier="aiida.calculations:quantumespresso.*"):
+        phonon_frequencies_builder: ProcessBuilder = (
+            _ConvergencePhononFrequenciesWorkChain.get_builder(
+                pseudo=pseudo_path("Al"),
+                protocol="test",
+                cutoff_list=[(20, 80), (30, 120)],
+                configuration="DC",
+                pw_code=code_generator("pw"),
+                ph_code=code_generator("ph"),
+                clean_workdir=True,
+            )
+        )
+
+        # Run a phonon_frequencies convergence workflow first and check that
+        # the SCF is not from a cached calcjob
+        _, source_node = run_get_node(phonon_frequencies_builder)
+
+        # Make the source ph calculation an invalid cache
+        source_ref_wf = [
+            p
+            for p in source_node.called
+            if p.base.extras.get("wavefunction_cutoff", None) == 30
+        ][0]
+        source_ph_calcjob_node = source_ref_wf.called[1].called[0]
+        assert source_ph_calcjob_node.is_valid_cache
+
+        source_ph_calcjob_node.is_valid_cache = False
+
+        # Run again and check it is using caching
+        _, cached_node = run_get_node(phonon_frequencies_builder)
+
+        cached_ref_wf = [
+            p
+            for p in cached_node.called
+            if p.base.extras.get("wavefunction_cutoff", None) == 30
+        ][0]
+        # Check the reference workchain, including the rerun pw preparation, finished okay
+        assert cached_ref_wf.is_finished_ok
+
+        cached_ph_calcjob_node = cached_ref_wf.called[1].called[0]
+        assert (
+            cached_ph_calcjob_node.base.extras.get("_aiida_cached_from", None) is None
+        )
+        assert cached_ph_calcjob_node.exit_code.status == 312
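+
+        # The non-zero exit statuses asserted here (312 for the ph calcjob, 305
+        # for the pw bands calcjob in the test above) are exit codes defined by
+        # the aiida-quantumespresso calculation plugins: the rerun hits a parent
+        # remote folder that `clean_workdir=True` already emptied, fails, and
+        # thereby exercises the workflow's path that reruns the preparing scf.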