diff --git a/src/aiida_pyscf/calculations/base.py b/src/aiida_pyscf/calculations/base.py index e2e8cde..573d4e4 100644 --- a/src/aiida_pyscf/calculations/base.py +++ b/src/aiida_pyscf/calculations/base.py @@ -2,7 +2,6 @@ """``CalcJob`` plugin for PySCF.""" from __future__ import annotations -import copy import io import numbers import pathlib @@ -130,15 +129,23 @@ def define(cls, spec: CalcJobProcessSpec): # type: ignore[override] ) @classmethod - def validate_parameters(cls, value: Dict | None, _) -> str | None: # noqa: PLR0911, PLR0912 + def get_valid_parameter_keys(cls) -> tuple[str, ...]: + """Return list of valid keys for the ``parameters`` input.""" + return ('mean_field', 'localize_orbitals', 'optimizer', 'cubegen', 'fcidump', 'hessian', 'results', 'structure') + + @classmethod + def validate_parameters(cls, parameters: Dict | None, _) -> str | None: # noqa: PLR0911, PLR0912 """Validate the parameters input.""" - if not value: + if not parameters: return None - parameters = copy.deepcopy(value.get_dict()) + unsupported_keys = set(parameters.keys()).difference(set(cls.get_valid_parameter_keys())) + + if unsupported_keys: + return f'The following arguments are not supported: {", ".join(unsupported_keys)}' - mean_field = parameters.pop('mean_field', {}) - mean_field_method = mean_field.pop('method', None) + mean_field = parameters.get('mean_field', {}) + mean_field_method = mean_field.get('method') valid_methods = ['RKS', 'RHF', 'DKS', 'DHF', 'GKS', 'GHF', 'HF', 'KS', 'ROHF', 'ROKS', 'UKS', 'UHF'] options = ' '.join(valid_methods) @@ -154,7 +161,7 @@ def validate_parameters(cls, value: Dict | None, _) -> str | None: # noqa: PLR0 'plugin if the `checkpoint` input is provided.' ) - if (localize_orbitals := parameters.pop('localize_orbitals', None)) is not None: + if (localize_orbitals := parameters.get('localize_orbitals')) is not None: valid_lo = ('boys', 'cholesky', 'edmiston', 'iao', 'ibo', 'lowdin', 'nao', 'orth', 'pipek', 'vvo') method = localize_orbitals.get('method') if method is None: @@ -163,7 +170,7 @@ def validate_parameters(cls, value: Dict | None, _) -> str | None: # noqa: PLR0 if method.lower() not in valid_lo: return f'Invalid method `{method}` specified in `localize_orbitals` parameters. Choose from: {valid_lo}' - if (optimizer := parameters.pop('optimizer', None)) is not None: + if (optimizer := parameters.get('optimizer')) is not None: valid_solvers = ('geometric', 'berny') solver = optimizer.get('solver') @@ -173,7 +180,7 @@ def validate_parameters(cls, value: Dict | None, _) -> str | None: # noqa: PLR0 if solver.lower() not in valid_solvers: return f'Invalid solver `{solver}` specified in `optimizer` parameters. Choose from: {valid_solvers}' - if (cubegen := parameters.pop('cubegen', None)) is not None: + if (cubegen := parameters.get('cubegen')) is not None: orbitals = cubegen.get('orbitals') indices = orbitals.get('indices') if orbitals is not None else None @@ -186,7 +193,7 @@ def validate_parameters(cls, value: Dict | None, _) -> str | None: # noqa: PLR0 if indices is not None and (not isinstance(indices, list) or any(not isinstance(e, int) for e in indices)): return f'The `cubegen.orbitals.indices` parameter should be a list of integers, but got: {indices}' - if (fcidump := parameters.pop('fcidump', None)) is not None: + if (fcidump := parameters.get('fcidump')) is not None: active_spaces = fcidump.get('active_spaces') occupations = fcidump.get('occupations') arrays = [] @@ -205,13 +212,6 @@ def validate_parameters(cls, value: Dict | None, _) -> str | None: # noqa: PLR0 if arrays[0].shape != arrays[1].shape: return 'The `fcipdump.active_spaces` and `fcipdump.occupations` arrays have different shapes.' - # Remove other known arguments - for key in ('hessian', 'results', 'structure'): - parameters.pop(key, None) - - if unknown_keys := list(parameters.keys()): - return f'The following arguments are not supported: {", ".join(unknown_keys)}' - def get_template_environment(self) -> Environment: """Return the template environment that should be used for rendering.