Skip to content

Commit

Permalink
PyscfCalculation: Validate parameters for unknown arguments
Browse files Browse the repository at this point in the history
The validator for the `parameters` input is updated to check for any
unknown arguments. This now raises an exception instead of silently
ignoring them and starting the calculation.
  • Loading branch information
sphuber committed Oct 24, 2023
1 parent 671245e commit 590d0ca
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 10 deletions.
24 changes: 14 additions & 10 deletions src/aiida_pyscf/calculations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,15 @@ def define(cls, spec: CalcJobProcessSpec): # type: ignore[override]
)

@classmethod
def validate_parameters(cls, value: Dict | None, _) -> str | None: # pylint: disable=too-many-return-statements,too-many-branches
def validate_parameters(cls, value: Dict | None, _) -> str | None: # pylint: disable=too-many-return-statements,too-many-branches,too-many-locals
"""Validate the parameters input."""
if not value:
return None

parameters = value.get_dict()

mean_field_method = parameters.get('mean_field', {}).get('method')
mean_field = parameters.pop('mean_field', {})
mean_field_method = mean_field.pop('method', None)
valid_methods = ['RKS', 'RHF', 'DKS', 'DHF', 'GKS', 'GHF', 'HF', 'KS', 'ROHF', 'ROKS', 'UKS', 'UHF']
options = ' '.join(valid_methods)

Expand All @@ -145,24 +146,24 @@ def validate_parameters(cls, value: Dict | None, _) -> str | None: # pylint: di
if mean_field_method not in valid_methods:
return f'Specified mean field method {mean_field_method} is not supported, choose from: {options}'

if 'chkfile' in parameters.get('mean_field', {}):
if 'chkfile' in mean_field:
return (
'The `chkfile` cannot be specified in the `mean_field` parameters. It is set automatically by the '
'plugin if the `checkpoint` input is provided.'
)

if 'optimizer' in parameters:
if optimizer := parameters.pop('optimizer', None):
valid_solvers = ('geometric', 'berny')
solver = parameters['optimizer'].get('solver')
solver = optimizer.get('solver')

if solver is None:
return f'No solver specified in `optimizer` parameters. Choose from: {valid_solvers}'

if solver.lower() not in valid_solvers:
return f'Invalid solver `{solver}` specified in `optimizer` parameters. Choose from: {valid_solvers}'

if 'cubegen' in parameters:
orbitals = parameters['cubegen'].get('orbitals')
if cubegen := parameters.pop('cubegen', None):
orbitals = cubegen.get('orbitals')
indices = orbitals.get('indices') if orbitals is not None else None

if orbitals is not None and indices is None:
Expand All @@ -174,9 +175,9 @@ def validate_parameters(cls, value: Dict | None, _) -> str | None: # pylint: di
if indices is not None and (not isinstance(indices, list) or any(not isinstance(e, int) for e in indices)):
return f'The `cubegen.orbitals.indices` parameter should be a list of integers, but got: {indices}'

if 'fcidump' in parameters:
active_spaces = parameters['fcidump'].get('active_spaces')
occupations = parameters['fcidump'].get('occupations')
if fcidump := parameters.pop('fcidump', None):
active_spaces = fcidump.get('active_spaces')
occupations = fcidump.get('occupations')
arrays = []

for key, data in (('active_spaces', active_spaces), ('occupations', occupations)):
Expand All @@ -193,6 +194,9 @@ def validate_parameters(cls, value: Dict | None, _) -> str | None: # pylint: di
if arrays[0].shape != arrays[1].shape:
return 'The `fcipdump.active_spaces` and `fcipdump.occupations` arrays have different shapes.'

if unknown_keys := list(parameters.keys()):
return f'The following arguments are not supported: {", ".join(unknown_keys)}'

def get_template_environment(self) -> Environment:
"""Return the template environment that should be used for rendering.
Expand Down
9 changes: 9 additions & 0 deletions tests/calculations/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,15 @@ def test_invalid_parameters_mean_field_chkfile(generate_calc_job, generate_input
generate_calc_job(PyscfCalculation, inputs=inputs)


def test_invalid_parameters_unknown_arguments(generate_calc_job, generate_inputs_pyscf):
"""Test validation of ``parameters`` raises if unknown arguments are included."""
parameters = {'unknown_key': 'value'}
inputs = generate_inputs_pyscf(parameters=parameters)

with pytest.raises(ValueError, match=r'The following arguments are not supported: unknown_key'):
generate_calc_job(PyscfCalculation, inputs=inputs)


@pytest.mark.parametrize(
'parameters, expected', (
({}, r'No solver specified in `optimizer` parameters'),
Expand Down

0 comments on commit 590d0ca

Please sign in to comment.