Skip to content

Commit

Permalink
Add support for different DIIS schemes.
Browse files Browse the repository at this point in the history
Previously all mean field option were rendered as single quoted strings.
For the special case of setting the DIIS scheme, this would be rendered
as `mean_field.DIIS = scf.'ADIIS'` , and would break.

This commit adds special handling for the 'scf.solver' key.
  • Loading branch information
ConradJohnston committed Nov 26, 2024
1 parent 684b4c2 commit 4865ba0
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 2 deletions.
16 changes: 16 additions & 0 deletions src/aiida_pyscf/calculations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,22 @@ def validate_parameters(cls, value: Dict | None, _) -> str | None: # noqa: PLR0
'plugin if the `checkpoint` input is provided.'
)

if (scf_solver := mean_field.get('solver')) is not None:
valid_scf_solvers = ('DIIS', 'CDIIS', 'EDIIS', 'ADIIS')
options = ' '.join(valid_scf_solvers)

if scf_solver is None:
return f'No solver specified in `mean_field.solver` parameters. Choose from: {options}'

if scf_solver.upper() == 'DIIS':
scf_solver = 'CDIIS'
return '`DIIS` is an alias for CDIIS in PySCF. Using `CDIIS` explicitly instead.'

if scf_solver.upper() not in valid_scf_solvers:
return (
f'Invalid solver `{scf_solver}` specified in `mean_field.solver` parameters. Choose from: {options}'
)

if (localize_orbitals := parameters.get('localize_orbitals')) is not None:
valid_lo = ('boys', 'cholesky', 'edmiston', 'iao', 'ibo', 'lowdin', 'nao', 'orth', 'pipek', 'vvo')
method = localize_orbitals.get('method')
Expand Down
5 changes: 4 additions & 1 deletion src/aiida_pyscf/calculations/templates/mean_field.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
from pyscf import scf
mean_field = scf.{{ mean_field.method }}(structure)
{% if mean_field %}
{% if mean_field.solver %}
mean_field.DIIS = scf.{{ mean_field.solver}}
{% endif %}
{% for key, value in mean_field.items() %}
{% if key not in ['method', 'checkpoint'] %}
{% if key not in ['checkpoint', 'method', 'solver'] %}
mean_field.{{ key }} = {{ value|render_python }}
{% endif %}
{% endfor %}
Expand Down
23 changes: 22 additions & 1 deletion tests/calculations/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def test_parameters_mean_field(generate_calc_job, generate_inputs_pyscf, file_re
'mean_field': {
'diis_start_cycle': 2,
'method': 'RHF',
'solver': 'ADIIS',
'grids': {'level': 3},
'xc': 'PBE',
},
Expand Down Expand Up @@ -186,7 +187,27 @@ def test_invalid_parameters_mean_field_method(generate_calc_job, generate_inputs
parameters = {'mean_field': {'method': 'invalid'}}
inputs = generate_inputs_pyscf(parameters=parameters)

with pytest.raises(ValueError, match=r'Specified mean field method invalid is not supported'):
with pytest.raises(ValueError, match=r'Specified mean field method invalid is not supported, choose from: '
r'RKS RHF DKS DHF GKS GHF HF KS ROHF ROKS UKS UHF'):
generate_calc_job(PyscfCalculation, inputs=inputs)


def test_invalid_parameters_mean_field_solver(generate_calc_job, generate_inputs_pyscf):
"""Test validation of ``parameters.mean_field.solver``."""
parameters = {'mean_field': {'solver': 'invalid'}}
inputs = generate_inputs_pyscf(parameters=parameters)

with pytest.raises(ValueError, match=r'Invalid solver `invalid` specified in `mean_field.solver` parameters. '
r'Choose from: DIIS CDIIS EDIIS ADIIS'):
generate_calc_job(PyscfCalculation, inputs=inputs)


def test_invalid_parameters_mean_field_solver_diis(generate_calc_job, generate_inputs_pyscf):
"""Test logic to catch `DIIS` solver input for ``parameters.mean_field.solver``."""
parameters = {'mean_field': {'solver': 'DIIS'}}
inputs = generate_inputs_pyscf(parameters=parameters)

with pytest.raises(ValueError, match=r'`DIIS` is an alias for CDIIS in PySCF. Using `CDIIS` explicitly instead.'):
generate_calc_job(PyscfCalculation, inputs=inputs)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def main():
# Section: Mean field
from pyscf import scf
mean_field = scf.RHF(structure)
mean_field.DIIS = scf.ADIIS
mean_field.xc = 'PBE'
mean_field.grids = {'level': 3}
mean_field.diis_start_cycle = 2
Expand Down

0 comments on commit 4865ba0

Please sign in to comment.