diff --git a/setup.cfg b/setup.cfg index 0af921db..4b46adb7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -47,6 +47,7 @@ install_requires = numpy pandas protobuf<=3.20.1 + pydantic pyoos python-dateutil ruamel.yaml @@ -83,7 +84,6 @@ dev = types-python-dateutil [options.package_data] -* = *.yaml ewatercycle = py.typed [coverage:run] diff --git a/src/ewatercycle/config/_config_object.py b/src/ewatercycle/config/_config_object.py index ddaf145c..7b12e753 100644 --- a/src/ewatercycle/config/_config_object.py +++ b/src/ewatercycle/config/_config_object.py @@ -4,26 +4,69 @@ from io import StringIO from logging import getLogger from pathlib import Path -from typing import Optional, TextIO, Union +from typing import Any, Dict, Literal, Optional, Set, TextIO, Union +from pydantic import BaseModel, DirectoryPath, FilePath, root_validator from ruamel.yaml import YAML from ewatercycle.util import to_absolute_path -from ._validated_config import ValidatedConfig -from ._validators import _validators - logger = getLogger(__name__) -class Config(ValidatedConfig): +# TODO dont duplicate +# src/ewatercycle/parameter_sets/default.py:ParameterSet +# but fix circular dependency +class ParameterSetConfig(BaseModel): + # TODO prepend directory with CFG.parameterset_dir + # and make DirectoryPath type + directory: Path + # TODO prepend config with CFG.parameterset_dir and .directory + # and make FilePath type + config: Path + doi: str = "N/A" + target_model: str = "generic" + supported_model_versions: Set[str] = set() + + +class Config(BaseModel): """Configuration object. Do not instantiate this class directly, but use :obj:`ewatercycle.CFG` instead. """ - _validate = _validators + grdc_location: Optional[DirectoryPath] + container_engine: Literal["docker", "apptainer", "singularity"] = "docker" + apptainer_dir: Optional[DirectoryPath] + singularity_dir: Optional[DirectoryPath] + output_dir: Optional[DirectoryPath] + parameterset_dir: Optional[DirectoryPath] + parameter_sets: Dict[str, ParameterSetConfig] = {} + ewatercycle_config: Optional[FilePath] + + @root_validator + def _deprecate_singularity_dir(cls, values): + singularity_dir = values.get("singularity_dir") + apptainer_dir = values.get("apptainer_dir") + if singularity_dir is not None and apptainer_dir is None: + logger.warn("singularity_dir has been deprecated please use apptainer_dir") + values["apptainer_dir"] = singularity_dir + return values + + # TODO add more cross property validation like + # - When container engine is apptainer then apptainer_dir must be set + # - When parameter_sets is filled then parameterset_dir must be set + + # TODO drop dict methods and use CFG.bla instead of CFG['bla'] everywhere else + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, value): + setattr(self, key, value) + + def __delitem__(self, key): + setattr(key, None) @classmethod def _load_user_config(cls, filename: Union[os.PathLike, str]) -> "Config": @@ -36,23 +79,21 @@ def _load_user_config(cls, filename: Union[os.PathLike, str]) -> "Config": filename: pathlike Name of the config file, must be yaml format """ - new = cls() + new: Dict[str, Any] = {} mapping = read_config_file(filename) mapping["ewatercycle_config"] = filename new.update(CFG_DEFAULT) new.update(mapping) - return new + return cls(**new) @classmethod def _load_default_config(cls, filename: Union[os.PathLike, str]) -> "Config": """Load the default configuration.""" - new = cls() mapping = read_config_file(filename) - new.update(mapping) - return new + return cls(**mapping) def load_from_file(self, filename: Union[os.PathLike, str]) -> None: """Load user configuration from the given file.""" @@ -76,15 +117,8 @@ def dump_to_yaml(self) -> str: return stream.getvalue() def _save_to_stream(self, stream: TextIO): - cp = self.copy() - # Exclude own path from dump - cp.pop("ewatercycle_config", None) - - cp["grdc_location"] = str(cp["grdc_location"]) - cp["apptainer_dir"] = str(cp["apptainer_dir"]) - cp["output_dir"] = str(cp["output_dir"]) - cp["parameterset_dir"] = str(cp["parameterset_dir"]) + cp = self.dict(exclude={"ewatercycle_config"}) yaml = YAML(typ="safe") yaml.dump(cp, stream) @@ -99,7 +133,7 @@ def save_to_file(self, config_file: Optional[Union[os.PathLike, str]] = None): the location in users home directory. """ # Exclude own path from dump - old_config_file = self.get("ewatercycle_config", None) + old_config_file = self.ewatercycle_config if config_file is None: config_file = ( @@ -154,7 +188,7 @@ def find_user_config(sources: tuple) -> Optional[os.PathLike]: USER_CONFIG = find_user_config(SOURCES) DEFAULT_CONFIG = Path(__file__).parent / FILENAME -CFG_DEFAULT = Config._load_default_config(DEFAULT_CONFIG) +CFG_DEFAULT = Config() if USER_CONFIG: CFG = Config._load_user_config(USER_CONFIG) diff --git a/src/ewatercycle/config/_validated_config.py b/src/ewatercycle/config/_validated_config.py deleted file mode 100644 index ae02089f..00000000 --- a/src/ewatercycle/config/_validated_config.py +++ /dev/null @@ -1,82 +0,0 @@ -"""Config validation objects.""" - -import pprint -from collections.abc import MutableMapping -from typing import Callable, Dict - -from ._validators import ValidationError - - -class InvalidConfigParameter(Exception): - """Config parameter is invalid.""" - - -class MissingConfigParameter(UserWarning): - """Config parameter is missing.""" - - -# The code for this class was take from matplotlib (v3.3) and modified to -# fit the needs of eWaterCycle. Matplotlib is licenced under the terms of -# the the 'Python Software Foundation License' -# (https://www.python.org/psf/license) -class ValidatedConfig(MutableMapping): - """Based on `matplotlib.rcParams`.""" - - _validate: Dict[str, Callable] = {} - - # validate values on the way in - def __init__(self, *args, **kwargs): - super().__init__() - self._mapping = {} - self.update(*args, **kwargs) - - def __setitem__(self, key, val): - """Map key to value.""" - try: - cval = self._validate[key](val) - except ValidationError as verr: - raise InvalidConfigParameter(f"Key `{key}`: {verr}") from None - except KeyError: - raise InvalidConfigParameter( - f"`{key}` is not a valid config parameter." - ) from None - - self._mapping[key] = cval - - def __getitem__(self, key): - """Return value mapped by key.""" - return self._mapping[key] - - def __repr__(self): - """Return canonical string representation.""" - class_name = self.__class__.__name__ - indent = len(class_name) + 1 - repr_split = pprint.pformat(self._mapping, indent=1, width=80 - indent).split( - "\n" - ) - repr_indented = ("\n" + " " * indent).join(repr_split) - return "{}({})".format(class_name, repr_indented) - - def __str__(self): - """Return string representation.""" - return "\n".join(map("{0[0]}: {0[1]}".format, sorted(self._mapping.items()))) - - def __iter__(self): - """Yield sorted list of keys.""" - yield from sorted(self._mapping) - - def __len__(self): - """Return number of config keys.""" - return len(self._mapping) - - def __delitem__(self, key): - """Delete key/value from config.""" - del self._mapping[key] - - def copy(self): - """Copy the keys/values of this object to a dict.""" - return {k: self._mapping[k] for k in self} - - def clear(self): - """Clear Config.""" - self._mapping.clear() diff --git a/src/ewatercycle/config/ewatercycle.yaml b/src/ewatercycle/config/ewatercycle.yaml deleted file mode 100644 index 15f0e2e5..00000000 --- a/src/ewatercycle/config/ewatercycle.yaml +++ /dev/null @@ -1,6 +0,0 @@ -grdc_location: null -container_engine: null -apptainer_dir: null -output_dir: null -parameterset_dir: null -parameter_sets: null diff --git a/src/ewatercycle/parameter_sets/__init__.py b/src/ewatercycle/parameter_sets/__init__.py index 6227b083..801e9a32 100644 --- a/src/ewatercycle/parameter_sets/__init__.py +++ b/src/ewatercycle/parameter_sets/__init__.py @@ -15,10 +15,8 @@ def _parse_parametersets(): parametersets = {} - if CFG["parameter_sets"] is None: - return [] - for name, options in CFG["parameter_sets"].items(): - parameterset = ParameterSet(name=name, **options) + for name, options in CFG.parameter_sets.items(): + parameterset = ParameterSet(name=name, **options.dict()) parametersets[name] = parameterset return parametersets