From 590d0ca7f113d6d35572006bcf41bbd51e07199c Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Wed, 25 Oct 2023 00:35:52 +0200 Subject: [PATCH] `PyscfCalculation`: Validate `parameters` for unknown arguments The validator for the `parameters` input is updated to check for any unknown arguments. This now raises an exception instead of silently ignoring them and starting the calculation. --- src/aiida_pyscf/calculations/base.py | 24 ++++++++++++++---------- tests/calculations/test_base.py | 9 +++++++++ 2 files changed, 23 insertions(+), 10 deletions(-) diff --git a/src/aiida_pyscf/calculations/base.py b/src/aiida_pyscf/calculations/base.py index df0dc27..eb3ff64 100644 --- a/src/aiida_pyscf/calculations/base.py +++ b/src/aiida_pyscf/calculations/base.py @@ -128,14 +128,15 @@ def define(cls, spec: CalcJobProcessSpec): # type: ignore[override] ) @classmethod - def validate_parameters(cls, value: Dict | None, _) -> str | None: # pylint: disable=too-many-return-statements,too-many-branches + def validate_parameters(cls, value: Dict | None, _) -> str | None: # pylint: disable=too-many-return-statements,too-many-branches,too-many-locals """Validate the parameters input.""" if not value: return None parameters = value.get_dict() - mean_field_method = parameters.get('mean_field', {}).get('method') + mean_field = parameters.pop('mean_field', {}) + mean_field_method = mean_field.pop('method', None) valid_methods = ['RKS', 'RHF', 'DKS', 'DHF', 'GKS', 'GHF', 'HF', 'KS', 'ROHF', 'ROKS', 'UKS', 'UHF'] options = ' '.join(valid_methods) @@ -145,15 +146,15 @@ def validate_parameters(cls, value: Dict | None, _) -> str | None: # pylint: di if mean_field_method not in valid_methods: return f'Specified mean field method {mean_field_method} is not supported, choose from: {options}' - if 'chkfile' in parameters.get('mean_field', {}): + if 'chkfile' in mean_field: return ( 'The `chkfile` cannot be specified in the `mean_field` parameters. It is set automatically by the ' 'plugin if the `checkpoint` input is provided.' ) - if 'optimizer' in parameters: + if optimizer := parameters.pop('optimizer', None): valid_solvers = ('geometric', 'berny') - solver = parameters['optimizer'].get('solver') + solver = optimizer.get('solver') if solver is None: return f'No solver specified in `optimizer` parameters. Choose from: {valid_solvers}' @@ -161,8 +162,8 @@ def validate_parameters(cls, value: Dict | None, _) -> str | None: # pylint: di if solver.lower() not in valid_solvers: return f'Invalid solver `{solver}` specified in `optimizer` parameters. Choose from: {valid_solvers}' - if 'cubegen' in parameters: - orbitals = parameters['cubegen'].get('orbitals') + if cubegen := parameters.pop('cubegen', None): + orbitals = cubegen.get('orbitals') indices = orbitals.get('indices') if orbitals is not None else None if orbitals is not None and indices is None: @@ -174,9 +175,9 @@ def validate_parameters(cls, value: Dict | None, _) -> str | None: # pylint: di if indices is not None and (not isinstance(indices, list) or any(not isinstance(e, int) for e in indices)): return f'The `cubegen.orbitals.indices` parameter should be a list of integers, but got: {indices}' - if 'fcidump' in parameters: - active_spaces = parameters['fcidump'].get('active_spaces') - occupations = parameters['fcidump'].get('occupations') + if fcidump := parameters.pop('fcidump', None): + active_spaces = fcidump.get('active_spaces') + occupations = fcidump.get('occupations') arrays = [] for key, data in (('active_spaces', active_spaces), ('occupations', occupations)): @@ -193,6 +194,9 @@ def validate_parameters(cls, value: Dict | None, _) -> str | None: # pylint: di if arrays[0].shape != arrays[1].shape: return 'The `fcipdump.active_spaces` and `fcipdump.occupations` arrays have different shapes.' + if unknown_keys := list(parameters.keys()): + return f'The following arguments are not supported: {", ".join(unknown_keys)}' + def get_template_environment(self) -> Environment: """Return the template environment that should be used for rendering. diff --git a/tests/calculations/test_base.py b/tests/calculations/test_base.py index 5b3a9de..05f4b4b 100644 --- a/tests/calculations/test_base.py +++ b/tests/calculations/test_base.py @@ -186,6 +186,15 @@ def test_invalid_parameters_mean_field_chkfile(generate_calc_job, generate_input generate_calc_job(PyscfCalculation, inputs=inputs) +def test_invalid_parameters_unknown_arguments(generate_calc_job, generate_inputs_pyscf): + """Test validation of ``parameters`` raises if unknown arguments are included.""" + parameters = {'unknown_key': 'value'} + inputs = generate_inputs_pyscf(parameters=parameters) + + with pytest.raises(ValueError, match=r'The following arguments are not supported: unknown_key'): + generate_calc_job(PyscfCalculation, inputs=inputs) + + @pytest.mark.parametrize( 'parameters, expected', ( ({}, r'No solver specified in `optimizer` parameters'),