Skip to content

Commit

Permalink
PyscfCalculation: Make valid keys in parameters easy to modify
Browse files Browse the repository at this point in the history
The `validate_parameters` was recently updated to check for unknown keys
being passed. Subclasses may want to add additional keys so the tuple of
known keys is abstracted and returned by the `get_valid_parameter_keys`
classmethod.

The tuple of supported keys is now made complete such that the
`validate_parameters` implementation can be cleaned up and move the
validation of the keys to the beginning and a copy of the parameters is
no longer needed since it no longer needs to pop known keys.
  • Loading branch information
sphuber committed Jan 3, 2024
1 parent 51431e3 commit 5b64b1b
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions src/aiida_pyscf/calculations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
"""``CalcJob`` plugin for PySCF."""
from __future__ import annotations

import copy
import io
import numbers
import pathlib
Expand Down Expand Up @@ -130,15 +129,23 @@ def define(cls, spec: CalcJobProcessSpec): # type: ignore[override]
)

@classmethod
def validate_parameters(cls, value: Dict | None, _) -> str | None: # noqa: PLR0911, PLR0912
def get_valid_parameter_keys(cls) -> tuple[str, ...]:
"""Return list of valid keys for the ``parameters`` input."""
return ('mean_field', 'localize_orbitals', 'optimizer', 'cubegen', 'fcidump', 'hessian', 'results', 'structure')

@classmethod
def validate_parameters(cls, parameters: Dict | None, _) -> str | None: # noqa: PLR0911, PLR0912
"""Validate the parameters input."""
if not value:
if not parameters:
return None

parameters = copy.deepcopy(value.get_dict())
unsupported_keys = set(parameters.keys()).difference(set(cls.get_valid_parameter_keys()))

if unsupported_keys:
return f'The following arguments are not supported: {", ".join(unsupported_keys)}'

mean_field = parameters.pop('mean_field', {})
mean_field_method = mean_field.pop('method', None)
mean_field = parameters.get('mean_field', {})
mean_field_method = mean_field.get('method')
valid_methods = ['RKS', 'RHF', 'DKS', 'DHF', 'GKS', 'GHF', 'HF', 'KS', 'ROHF', 'ROKS', 'UKS', 'UHF']
options = ' '.join(valid_methods)

Expand All @@ -154,7 +161,7 @@ def validate_parameters(cls, value: Dict | None, _) -> str | None: # noqa: PLR0
'plugin if the `checkpoint` input is provided.'
)

if (localize_orbitals := parameters.pop('localize_orbitals', None)) is not None:
if (localize_orbitals := parameters.get('localize_orbitals')) is not None:
valid_lo = ('boys', 'cholesky', 'edmiston', 'iao', 'ibo', 'lowdin', 'nao', 'orth', 'pipek', 'vvo')
method = localize_orbitals.get('method')
if method is None:
Expand All @@ -163,7 +170,7 @@ def validate_parameters(cls, value: Dict | None, _) -> str | None: # noqa: PLR0
if method.lower() not in valid_lo:
return f'Invalid method `{method}` specified in `localize_orbitals` parameters. Choose from: {valid_lo}'

if (optimizer := parameters.pop('optimizer', None)) is not None:
if (optimizer := parameters.get('optimizer')) is not None:
valid_solvers = ('geometric', 'berny')
solver = optimizer.get('solver')

Expand All @@ -173,7 +180,7 @@ def validate_parameters(cls, value: Dict | None, _) -> str | None: # noqa: PLR0
if solver.lower() not in valid_solvers:
return f'Invalid solver `{solver}` specified in `optimizer` parameters. Choose from: {valid_solvers}'

if (cubegen := parameters.pop('cubegen', None)) is not None:
if (cubegen := parameters.get('cubegen')) is not None:
orbitals = cubegen.get('orbitals')
indices = orbitals.get('indices') if orbitals is not None else None

Expand All @@ -186,7 +193,7 @@ def validate_parameters(cls, value: Dict | None, _) -> str | None: # noqa: PLR0
if indices is not None and (not isinstance(indices, list) or any(not isinstance(e, int) for e in indices)):
return f'The `cubegen.orbitals.indices` parameter should be a list of integers, but got: {indices}'

if (fcidump := parameters.pop('fcidump', None)) is not None:
if (fcidump := parameters.get('fcidump')) is not None:
active_spaces = fcidump.get('active_spaces')
occupations = fcidump.get('occupations')
arrays = []
Expand All @@ -205,13 +212,6 @@ def validate_parameters(cls, value: Dict | None, _) -> str | None: # noqa: PLR0
if arrays[0].shape != arrays[1].shape:
return 'The `fcipdump.active_spaces` and `fcipdump.occupations` arrays have different shapes.'

# Remove other known arguments
for key in ('hessian', 'results', 'structure'):
parameters.pop(key, None)

if unknown_keys := list(parameters.keys()):
return f'The following arguments are not supported: {", ".join(unknown_keys)}'

def get_template_environment(self) -> Environment:
"""Return the template environment that should be used for rendering.
Expand Down

0 comments on commit 5b64b1b

Please sign in to comment.