From 144d63bc52717a81c374c6b0202a96c4a5cab005 Mon Sep 17 00:00:00 2001 From: ElliottKasoar <45317199+ElliottKasoar@users.noreply.github.com> Date: Thu, 25 Jul 2024 15:09:26 +0100 Subject: [PATCH] Update Atoms info from calculated results (#221) * Ensure atoms info is consistent with files written * Add write-kwargs to MD * Add arch to Atoms.info and default writing results --------- Co-authored-by: Jacob Wilkins <46597752+oerc0122@users.noreply.github.com> --- janus_core/calculations/eos.py | 32 +++--- janus_core/calculations/geom_opt.py | 27 +++-- janus_core/calculations/md.py | 55 +++++++---- janus_core/calculations/single_point.py | 125 +++++------------------- janus_core/cli/eos.py | 2 +- janus_core/cli/geomopt.py | 2 +- janus_core/cli/md.py | 8 ++ janus_core/helpers/janus_types.py | 12 ++- janus_core/helpers/mlip_calculators.py | 1 + janus_core/helpers/utils.py | 105 +++++++++++++++++++- tests/test_correlator.py | 65 ++++++------ tests/test_geom_opt.py | 2 +- tests/test_geomopt_cli.py | 6 +- tests/test_md_cli.py | 54 ++++++++++ tests/test_single_point.py | 16 +++ tests/test_singlepoint_cli.py | 24 +++++ tests/test_utils.py | 94 +++++++++++++++++- 17 files changed, 444 insertions(+), 186 deletions(-) diff --git a/janus_core/calculations/eos.py b/janus_core/calculations/eos.py index 6db9a02a..d5b4b422 100644 --- a/janus_core/calculations/eos.py +++ b/janus_core/calculations/eos.py @@ -5,16 +5,15 @@ from ase import Atoms from ase.eos import EquationOfState -from ase.io import write from ase.units import kJ from codecarbon import OfflineEmissionsTracker from numpy import float64, linspace from numpy.typing import NDArray from janus_core.calculations.geom_opt import optimize -from janus_core.helpers.janus_types import ASEWriteArgs, EoSNames, EoSResults, PathLike +from janus_core.helpers.janus_types import EoSNames, EoSResults, OutputKwargs, PathLike from janus_core.helpers.log import config_logger, config_tracker -from janus_core.helpers.utils import none_to_dict +from janus_core.helpers.utils import none_to_dict, output_structs def _calc_volumes_energies( # pylint: disable=too-many-locals @@ -26,7 +25,7 @@ def _calc_volumes_energies( # pylint: disable=too-many-locals minimize_all: bool = False, minimize_kwargs: Optional[dict[str, Any]] = None, write_structures: bool = False, - write_kwargs: Optional[ASEWriteArgs] = None, + write_kwargs: Optional[OutputKwargs] = None, logger: Optional[Logger] = None, tracker: Optional[OfflineEmissionsTracker] = None, ) -> tuple[NDArray[float64], list[float], list[float]]: @@ -50,7 +49,7 @@ def _calc_volumes_energies( # pylint: disable=too-many-locals chemical formula of the structure. write_structures : bool True to write out all genereated structures. Default is False. - write_kwargs : Optional[ASEWriteArgs], + write_kwargs : Optional[OutputKwargs], Keyword arguments to pass to ase.io.write to save generated structures. Default is {}. logger : Optional[Logger] @@ -88,10 +87,15 @@ def _calc_volumes_energies( # pylint: disable=too-many-locals volumes.append(c_struct.get_volume()) energies.append(c_struct.get_potential_energy()) - if write_structures: - # Always append first original structure - write_kwargs["append"] = True - write(images=c_struct, **write_kwargs) + # Always append first original structure + write_kwargs["append"] = True + # Write structures, but no need to set info c_struct is not used elsewhere + output_structs( + images=c_struct, + write_results=write_structures, + set_info=False, + write_kwargs=write_kwargs, + ) if logger: tracker.stop_task() @@ -113,7 +117,7 @@ def calc_eos( minimize_kwargs: Optional[dict[str, Any]] = None, write_results: bool = True, write_structures: bool = False, - write_kwargs: Optional[ASEWriteArgs] = None, + write_kwargs: Optional[OutputKwargs] = None, file_prefix: Optional[PathLike] = None, log_kwargs: Optional[dict[str, Any]] = None, tracker_kwargs: Optional[dict[str, Any]] = None, @@ -145,7 +149,7 @@ def calc_eos( True to write out results of equation of state calculations. Default is True. write_structures : bool True to write out all genereated structures. Default is False. - write_kwargs : Optional[ASEWriteArgs], + write_kwargs : Optional[OutputKwargs], Keyword arguments to pass to ase.io.write to save generated structures. Default is {}. file_prefix : Optional[PathLike] @@ -205,8 +209,10 @@ def calc_eos( } optimize(struct, **minimize_kwargs) - if write_structures: - write(images=struct, **write_kwargs) + # Optionally write structure to file + output_structs( + images=struct, write_results=write_structures, write_kwargs=write_kwargs + ) # Set constant volume for geometry optimization of generated structures if "filter_kwargs" in minimize_kwargs: diff --git a/janus_core/calculations/geom_opt.py b/janus_core/calculations/geom_opt.py index 431572c3..16e8e1be 100644 --- a/janus_core/calculations/geom_opt.py +++ b/janus_core/calculations/geom_opt.py @@ -7,15 +7,15 @@ from ase import Atoms, filters, units from ase.filters import FrechetCellFilter -from ase.io import read, write +from ase.io import read import ase.optimize from ase.optimize import LBFGS from ase.optimize.optimize import Optimizer from numpy import linalg -from janus_core.helpers.janus_types import ASEOptArgs, ASEWriteArgs +from janus_core.helpers.janus_types import ASEOptArgs, OutputKwargs from janus_core.helpers.log import config_logger, config_tracker -from janus_core.helpers.utils import none_to_dict, spacegroup +from janus_core.helpers.utils import none_to_dict, output_structs, spacegroup def _set_functions( @@ -126,8 +126,8 @@ def optimize( # pylint: disable=too-many-arguments,too-many-locals,too-many-bra optimizer: Callable = LBFGS, opt_kwargs: Optional[ASEOptArgs] = None, write_results: bool = False, - write_kwargs: Optional[ASEWriteArgs] = None, - traj_kwargs: Optional[ASEWriteArgs] = None, + write_kwargs: Optional[OutputKwargs] = None, + traj_kwargs: Optional[OutputKwargs] = None, log_kwargs: Optional[dict[str, Any]] = None, tracker_kwargs: Optional[dict[str, Any]] = None, ) -> Atoms: @@ -161,10 +161,10 @@ def optimize( # pylint: disable=too-many-arguments,too-many-locals,too-many-bra Keyword arguments to pass to optimizer. Default is {}. write_results : bool True to write out optimized structure. Default is False. - write_kwargs : Optional[ASEWriteArgs], + write_kwargs : Optional[OutputKwargs], Keyword arguments to pass to ase.io.write to save optimized structure. Default is {}. - traj_kwargs : Optional[ASEWriteArgs] + traj_kwargs : Optional[OutputKwargs] Keyword arguments to pass to ase.io.write to save optimization trajectory. Must include "filename" keyword. Default is {}. log_kwargs : Optional[dict[str, Any]] @@ -239,13 +239,20 @@ def optimize( # pylint: disable=too-many-arguments,too-many-locals,too-many-bra ) # Write out optimized structure - if write_results: - write(images=struct, **write_kwargs) + output_structs( + struct, + write_results=write_results, + write_kwargs=write_kwargs, + ) # Reformat trajectory file from binary if traj_kwargs: traj = read(opt_kwargs["trajectory"], index=":") - write(images=traj, **traj_kwargs) + output_structs( + traj, + write_results=True, + write_kwargs=traj_kwargs, + ) if logger: tracker.stop() diff --git a/janus_core/calculations/md.py b/janus_core/calculations/md.py index 0914321e..b6db34e8 100644 --- a/janus_core/calculations/md.py +++ b/janus_core/calculations/md.py @@ -12,7 +12,7 @@ from ase import Atoms, units from ase.geometry.analysis import Analysis -from ase.io import read, write +from ase.io import read from ase.md.langevin import Langevin from ase.md.npt import NPT as ASE_NPT from ase.md.velocitydistribution import ( @@ -29,12 +29,13 @@ from janus_core.helpers.janus_types import ( CorrelationKwargs, Ensembles, + OutputKwargs, PathLike, PostProcessKwargs, ) from janus_core.helpers.log import config_logger, config_tracker from janus_core.helpers.post_process import compute_rdf, compute_vaf -from janus_core.helpers.utils import FileNameMixin +from janus_core.helpers.utils import FileNameMixin, output_structs DENS_FACT = (units.m / 1.0e2) ** 3 / units.mol @@ -111,6 +112,9 @@ class MolecularDynamics(FileNameMixin): # pylint: disable=too-many-instance-att heating. temp_time : Optional[float] Time between heating steps, in fs. Default is None, which disables heating. + write_kwargs : Optional[OutputKwargs], + Keyword arguments to pass to `output_structs` when saving trajectory and final + files. Default is {}. post_process_kwargs : Optional[PostProcessKwargs] Keyword arguments to control post-processing operations. correlation_kwargs : Optional[CorrelationKwargs] @@ -182,6 +186,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals,too-many-sta temp_end: Optional[float] = None, temp_step: Optional[float] = None, temp_time: Optional[float] = None, + write_kwargs: Optional[OutputKwargs] = None, post_process_kwargs: Optional[PostProcessKwargs] = None, correlation_kwargs: Optional[list[CorrelationKwargs]] = None, log_kwargs: Optional[dict[str, Any]] = None, @@ -262,6 +267,9 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals,too-many-sta disables heating. temp_time : Optional[float] Time between heating steps, in fs. Default is None, which disables heating. + write_kwargs : Optional[OutputKwargs], + Keyword arguments to pass to `output_structs` when saving trajectory and + final files. Default is {}. post_process_kwargs : Optional[PostProcessKwargs] Keyword arguments to control post-processing operations. correlation_kwargs : Optional[list[CorrelationKwargs]] @@ -300,6 +308,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals,too-many-sta self.temp_end = temp_end self.temp_step = temp_step self.temp_time = temp_time * units.fs if temp_time else None + self.write_kwargs = write_kwargs if write_kwargs is not None else {} self.post_process_kwargs = ( post_process_kwargs if post_process_kwargs is not None else {} ) @@ -312,6 +321,12 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals,too-many-sta FileNameMixin.__init__(self, struct, struct_name, file_prefix, ensemble) + self.write_kwargs.setdefault( + "columns", ["symbols", "positions", "momenta", "masses"] + ) + if "append" in self.write_kwargs: + raise ValueError("`append` cannot be specified when writing files") + self.log_kwargs = ( log_kwargs if log_kwargs else {} ) # pylint: disable=duplicate-code @@ -595,11 +610,13 @@ def _write_traj(self) -> None: self.dyn.nsteps > self.traj_start + self.traj_start % self.traj_every ) - self.dyn.atoms.write( - self.traj_file, - write_info=True, - columns=["symbols", "positions", "momenta", "masses"], - append=append, + write_kwargs = self.write_kwargs + write_kwargs["filename"] = self.traj_file + write_kwargs["append"] = append + output_structs( + images=self.struct, + write_results=True, + write_kwargs=write_kwargs, ) def _write_final_state(self) -> None: @@ -611,12 +628,13 @@ def _write_final_state(self) -> None: # Append if final file has been created append = self.created_final_file - write( - self.final_file, - self.struct, - write_info=True, - columns=["symbols", "positions", "momenta", "masses"], - append=append, + write_kwargs = self.write_kwargs + write_kwargs["filename"] = self.final_file + write_kwargs["append"] = append + output_structs( + images=self.struct, + write_results=True, + write_kwargs=write_kwargs, ) def _post_process(self) -> None: @@ -708,11 +726,12 @@ def _write_restart(self) -> None: """Write restart file and (optionally) rotate files saved.""" step = self.offset + self.dyn.nsteps if step > 0: - write( - self._restart_file, - self.struct, - write_info=True, - columns=["symbols", "positions", "momenta", "masses"], + write_kwargs = self.write_kwargs + write_kwargs["filename"] = self._restart_file + output_structs( + images=self.struct, + write_results=True, + write_kwargs=write_kwargs, ) if self.rotate_restart: self.restart_files.append(self._restart_file) diff --git a/janus_core/calculations/single_point.py b/janus_core/calculations/single_point.py index cf035f62..c8a1da92 100644 --- a/janus_core/calculations/single_point.py +++ b/janus_core/calculations/single_point.py @@ -1,27 +1,28 @@ """Prepare and perform single point calculations.""" -from collections.abc import Collection +from collections.abc import Sequence from copy import deepcopy from pathlib import Path -from typing import Any, Optional +from typing import Any, Optional, get_args from ase import Atoms -from ase.io import read, write -from numpy import isfinite, ndarray +from ase.io import read +from numpy import ndarray from janus_core.helpers.janus_types import ( Architectures, ASEReadArgs, - ASEWriteArgs, CalcResults, Devices, MaybeList, MaybeSequence, + OutputKwargs, PathLike, + Properties, ) from janus_core.helpers.log import config_logger, config_tracker from janus_core.helpers.mlip_calculators import choose_calculator -from janus_core.helpers.utils import FileNameMixin, none_to_dict +from janus_core.helpers.utils import FileNameMixin, none_to_dict, output_structs class SinglePoint(FileNameMixin): # pylint: disable=too-many-instance-attributes @@ -218,7 +219,7 @@ def set_calculator( read_kwargs = read_kwargs if read_kwargs else {} self.read_structure(**read_kwargs) - if isinstance(self.struct, list): + if isinstance(self.struct, Sequence): for struct in self.struct: struct.calc = deepcopy(calculator) # Return single Atoms object if only one image in list @@ -236,15 +237,11 @@ def _get_potential_energy(self) -> MaybeList[float]: MaybeList[float] Potential energy of structure(s). """ - tag = f"{self.architecture}_energy" - if isinstance(self.struct, list): + if isinstance(self.struct, Sequence): energies = [struct.get_potential_energy() for struct in self.struct] - for struct, energy in zip(self.struct, energies): - struct.info[tag] = energy return energies energy = self.struct.get_potential_energy() - self.struct.info[tag] = energy return energy def _get_forces(self) -> MaybeList[ndarray]: @@ -256,15 +253,11 @@ def _get_forces(self) -> MaybeList[ndarray]: MaybeList[ndarray] Forces of structure(s). """ - tag = f"{self.architecture}_forces" - if isinstance(self.struct, list): + if isinstance(self.struct, Sequence): forces = [struct.get_forces() for struct in self.struct] - for struct, force in zip(self.struct, forces): - struct.arrays[tag] = force return forces force = self.struct.get_forces() - self.struct.arrays[tag] = force return force def _get_stress(self) -> MaybeList[ndarray]: @@ -276,91 +269,18 @@ def _get_stress(self) -> MaybeList[ndarray]: MaybeList[ndarray] Stress of structure(s). """ - tag = f"{self.architecture}_stress" - if isinstance(self.struct, list): + if isinstance(self.struct, Sequence): stresses = [struct.get_stress() for struct in self.struct] - for struct, stress in zip(self.struct, stresses): - struct.info[tag] = stress return stresses stress = self.struct.get_stress() - self.struct.info[tag] = stress return stress - def _remove_invalid_props( - self, - struct: Atoms, - results: CalcResults = None, - properties: Collection[str] = (), - ) -> None: - """ - Remove any invalid properties from calculated results. - - Parameters - ---------- - struct : Atoms - ASE Atoms structure with attached calculator results. - results : CalcResults - Dictionary of calculated results. Default is {}. - properties : Collection[str] - Physical properties requested to be calculated. Default is (). - """ - results = results if results else {} - - # Find any properties with non-finite values - rm_keys = [ - prop - for prop in struct.calc.results - if not isfinite(struct.calc.results[prop]).all() - ] - # Raise error if property was explicitly requested, otherwise remove - for prop in rm_keys: - if prop in properties: - raise ValueError( - f"'{prop}' contains non-finite values for this structure." - ) - if prop in results: - del struct.info[f"{self.architecture}_{prop}"] - del struct.calc.results[prop] - del results[prop] - - def _clean_results( - self, - results: CalcResults = None, - properties: Collection[str] = (), - invalidate_calc: bool = True, - ) -> None: - """ - Remove NaN and inf values from results and calc.results dictionaries. - - Parameters - ---------- - results : CalcResults - Dictionary of calculated results. Default is {}. - properties : Collection[str] - Physical properties requested to be calculated. Default is (). - invalidate_calc : bool - Remove calculator results if True. When True Atoms object loses - its property methods and true values are in info and arrays. - Default is True. - """ - results = results if results else {} - - if isinstance(self.struct, list): - for image in self.struct: - self._remove_invalid_props(image, results, properties) - if invalidate_calc: - image.calc.results = {} - else: - self._remove_invalid_props(self.struct, results, properties) - if invalidate_calc: - self.struct.calc.results = {} - def run( self, - properties: MaybeSequence[str] = (), + properties: MaybeSequence[Properties] = (), write_results: bool = False, - write_kwargs: Optional[ASEWriteArgs] = None, + write_kwargs: Optional[OutputKwargs] = None, ) -> CalcResults: """ Run single point calculations. @@ -372,7 +292,7 @@ def run( "forces", and "stress" will be returned. write_results : bool True to write out structure with results of calculations. Default is False. - write_kwargs : Optional[ASEWriteArgs], + write_kwargs : Optional[OutputKwargs], Keyword arguments to pass to ase.io.write if saving structure with results of calculations. Default is {}. @@ -386,11 +306,15 @@ def run( properties = [properties] for prop in properties: - if prop not in ["energy", "forces", "stress"]: + if prop not in get_args(Properties): raise NotImplementedError( f"Property '{prop}' cannot currently be calculated." ) + # If none specified, get all valid properties + if len(properties) == 0: + properties = get_args(Properties) + write_kwargs = write_kwargs if write_kwargs else {} write_kwargs.setdefault( @@ -409,15 +333,16 @@ def run( if "stress" in properties or len(properties) == 0: results["stress"] = self._get_stress() - # Remove meaningless values from results e.g. stress for non-periodic systems - self._clean_results(results, properties=properties) - if self.logger: self.tracker.stop_task() self.tracker.stop() self.logger.info("Single point calculation complete") - if write_results: - write(images=self.struct, **write_kwargs) + output_structs( + self.struct, + write_results=write_results, + properties=properties, + write_kwargs=write_kwargs, + ) return results diff --git a/janus_core/cli/eos.py b/janus_core/cli/eos.py index bbc69a57..9407c03f 100644 --- a/janus_core/cli/eos.py +++ b/janus_core/cli/eos.py @@ -117,7 +117,7 @@ def eos( Other keyword arguments to pass to geometry optimizer. Default is {}. write_structures : bool True to write out all genereated structures. Default is False. - write_kwargs : Optional[ASEWriteArgs], + write_kwargs : Optional[dict[str, Any]], Keyword arguments to pass to ase.io.write to save generated structures. Default is {}. arch : Optional[str] diff --git a/janus_core/cli/geomopt.py b/janus_core/cli/geomopt.py index 869ad2c8..51fb1ce7 100644 --- a/janus_core/cli/geomopt.py +++ b/janus_core/cli/geomopt.py @@ -197,7 +197,7 @@ def geomopt( Keyword arguments to pass to the selected calculator. Default is {}. minimize_kwargs : Optional[dict[str, Any]] Other keyword arguments to pass to geometry optimizer. Default is {}. - write_kwargs : Optional[ASEWriteArgs] + write_kwargs : Optional[dict[str, Any]] Keyword arguments to pass to ase.io.write when saving optimized structure. Default is {}. log : Optional[Path] diff --git a/janus_core/cli/md.py b/janus_core/cli/md.py index 7f8e83d7..76f838ef 100644 --- a/janus_core/cli/md.py +++ b/janus_core/cli/md.py @@ -20,6 +20,7 @@ ReadKwargs, StructPath, Summary, + WriteKwargs, ) from janus_core.cli.utils import ( check_config, @@ -184,6 +185,7 @@ def md( temp_time: Annotated[ float, Option(help="Time between heating steps, in fs.") ] = None, + write_kwargs: WriteKwargs = None, post_process_kwargs: PostProcessKwargs = None, log: LogPath = "md.log", seed: Annotated[ @@ -291,6 +293,9 @@ def md( temp_time : Optional[float] Time between heating steps, in fs. Default is None, which disables heating. + write_kwargs : Optional[dict[str, Any]], + Keyword arguments to pass to `output_structs` when saving trajectory and final + files. Default is {}. post_process_kwargs : Optional[PostProcessKwargs] Kwargs to pass to post-processing. log : Optional[Path] @@ -311,6 +316,7 @@ def md( calc_kwargs, minimize_kwargs, ensemble_kwargs, + write_kwargs, post_process_kwargs, ] = parse_typer_dicts( [ @@ -318,6 +324,7 @@ def md( calc_kwargs, minimize_kwargs, ensemble_kwargs, + write_kwargs, post_process_kwargs, ] ) @@ -374,6 +381,7 @@ def md( "temp_end": temp_end, "temp_step": temp_step, "temp_time": temp_time, + "write_kwargs": write_kwargs, "post_process_kwargs": post_process_kwargs, "log_kwargs": log_kwargs, "seed": seed, diff --git a/janus_core/helpers/janus_types.py b/janus_core/helpers/janus_types.py index 418bf290..ddf9703d 100644 --- a/janus_core/helpers/janus_types.py +++ b/janus_core/helpers/janus_types.py @@ -1,6 +1,6 @@ """Module containing types used in Janus-Core.""" -from collections.abc import Sequence +from collections.abc import Collection, Sequence from enum import Enum import logging from pathlib import Path, PurePath @@ -141,6 +141,16 @@ class CorrelationKwargs(TypedDict, total=True): Architectures = Literal["mace", "mace_mp", "mace_off", "m3gnet", "chgnet", "alignn"] Devices = Literal["cpu", "cuda", "mps", "xpu"] Ensembles = Literal["nph", "npt", "nve", "nvt", "nvt-nh"] +Properties = Literal["energy", "stress", "forces"] + + +class OutputKwargs(ASEWriteArgs, total=False): + """Main keyword arguments for `output_structs`.""" + + set_info: bool + write_results: bool + properties: Collection[Properties] + invalidate_calc: bool class LogLevel(Enum): # numpydoc ignore=PR01 diff --git a/janus_core/helpers/mlip_calculators.py b/janus_core/helpers/mlip_calculators.py index d9dc67fb..47c62a81 100644 --- a/janus_core/helpers/mlip_calculators.py +++ b/janus_core/helpers/mlip_calculators.py @@ -210,5 +210,6 @@ def choose_calculator( ) calculator.parameters["version"] = __version__ + calculator.parameters["arch"] = architecture return calculator diff --git a/janus_core/helpers/utils.py b/janus_core/helpers/utils.py index 088aab64..8f1f14cc 100644 --- a/janus_core/helpers/utils.py +++ b/janus_core/helpers/utils.py @@ -1,13 +1,20 @@ """Utility functions for janus_core.""" from abc import ABC +from collections.abc import Collection from pathlib import Path -from typing import Optional +from typing import Optional, get_args from ase import Atoms +from ase.io import write from spglib import get_spacegroup -from janus_core.helpers.janus_types import PathLike +from janus_core.helpers.janus_types import ( + ASEWriteArgs, + MaybeSequence, + PathLike, + Properties, +) class FileNameMixin(ABC): # pylint: disable=too-few-public-methods @@ -229,3 +236,97 @@ def dict_remove_hyphens(dictionary: dict) -> dict: if isinstance(value, dict): dictionary[key] = dict_remove_hyphens(value) return {k.replace("-", "_"): v for k, v in dictionary.items()} + + +def results_to_info( + struct: Atoms, + *, + properties: Collection[Properties] = (), + invalidate_calc: bool = False, +) -> None: + """ + Copy or move MLIP calculated results to Atoms.info dict. + + Parameters + ---------- + struct : Atoms + Atoms object to copy or move calculated results to info dict. + properties : Collection[Properties] + Properties to copy from results to info dict. Default is (). + invalidate_calc : bool + Whether to remove all calculator results after copying properties to info dict. + Default is False. + """ + if not properties: + properties = get_args(Properties) + + if struct.calc: + # Set default architecture from calculator name + arch = struct.calc.parameters["arch"] + struct.info["arch"] = arch + + for key in properties & struct.calc.results.keys(): + tag = f"{arch}_{key}" + value = struct.calc.results[key] + if key == "forces": + struct.arrays[tag] = value + else: + struct.info[tag] = value + + # Remove all calculator results + if invalidate_calc: + struct.calc.results = {} + + +def output_structs( + images: MaybeSequence[Atoms], + *, + set_info: bool = True, + write_results: bool = False, + properties: Collection[Properties] = (), + invalidate_calc: bool = False, + write_kwargs: Optional[ASEWriteArgs] = None, +) -> None: + """ + Copy or move calculated results to Atoms.info dict and/or write structures to file. + + Parameters + ---------- + images : MaybeSequence[Atoms] + Atoms object or a list of Atoms objects to interact with. + set_info : bool + True to set info dict from calculated results. Default is True. + write_results : bool + True to write out structure with results of calculations. Default is False. + properties : Collection[Properties] + Properties to copy from calculated results to info dict. Default is (). + invalidate_calc : bool + Whether to remove all calculator results after copying properties to info dict. + Default is False. + write_kwargs : Optional[ASEWriteArgs] + Keyword arguments passed to ase.io.write. Default is {}. + """ + # Separate kwargs for output_structs from kwargs for ase.io.write + # This assumes values passed via kwargs have priority over passed parameters + write_kwargs = write_kwargs if write_kwargs else {} + set_info = write_kwargs.pop("set_info", set_info) + properties = write_kwargs.pop("properties", properties) + invalidate_calc = write_kwargs.pop("invalidate_calc", invalidate_calc) + + if isinstance(images, Atoms): + images = (images,) + + if set_info: + for image in images: + results_to_info( + image, properties=properties, invalidate_calc=invalidate_calc + ) + else: + # Label architecture even if not copying results to info + for image in images: + if image.calc: + image.info["arch"] = image.calc.parameters["arch"] + + if write_results: + write_kwargs.setdefault("write_results", not invalidate_calc) + write(images=images, **write_kwargs) diff --git a/tests/test_correlator.py b/tests/test_correlator.py index 7bc338ab..094b883f 100644 --- a/tests/test_correlator.py +++ b/tests/test_correlator.py @@ -72,7 +72,6 @@ def test_md_correlations(tmp_path): """Test correlations as part of MD cycle.""" file_prefix = tmp_path / "Cl4Na4-nve-T300.0" traj_path = tmp_path / "Cl4Na4-nve-T300.0-traj.extxyz" - stats_path = tmp_path / "Cl4Na4-nve-T300.0-stats.dat" cor_path = tmp_path / "Cl4Na4-nve-T300.0-cor.dat" single_point = SinglePoint( @@ -119,38 +118,34 @@ def user_observable_a(atoms: Atoms, kappa, **kwargs) -> float: "update_frequency": 1, }, ], + write_kwargs={"invalidate_calc": False}, ) - - try: - nve.run() - pxy = [ - atom.get_stress(include_ideal_gas=True, voigt=False).flatten()[1] / GPa - for atom in read(traj_path, index=":") - ] - assert cor_path.exists() - with open(cor_path, encoding="utf8") as in_file: - cor = load(in_file, Loader=Loader) - assert len(cor) == 2 - assert "user_correlation" in cor - assert "stress_xy_auto_cor" in cor - - stress_cor = cor["stress_xy_auto_cor"] - value, lags = stress_cor["value"], stress_cor["lags"] - assert len(value) == len(lags) == 11 - - direct = correlate(pxy, pxy, fft=False) - # input data differs due to i/o, error is expected 1e-5 - assert direct == approx(value, rel=1e-5) - - user_cor = cor["user_correlation"] - value, lags = user_cor["value"], stress_cor["lags"] - assert len(value) == len(lags) == 11 - - direct = correlate([v * 4.0 for v in pxy], pxy, fft=False) - # input data differs due to i/o, error is expected 1e-5 - assert direct == approx(value, rel=1e-5) - - finally: - traj_path.unlink(missing_ok=True) - stats_path.unlink(missing_ok=True) - cor_path.unlink(missing_ok=True) + nve.run() + + pxy = [ + atom.get_stress(include_ideal_gas=True, voigt=False).flatten()[1] / GPa + for atom in read(traj_path, index=":") + ] + + assert cor_path.exists() + with open(cor_path, encoding="utf8") as in_file: + cor = load(in_file, Loader=Loader) + assert len(cor) == 2 + assert "user_correlation" in cor + assert "stress_xy_auto_cor" in cor + + stress_cor = cor["stress_xy_auto_cor"] + value, lags = stress_cor["value"], stress_cor["lags"] + assert len(value) == len(lags) == 11 + + direct = correlate(pxy, pxy, fft=False) + # input data differs due to i/o, error is expected 1e-5 + assert direct == approx(value, rel=1e-5) + + user_cor = cor["user_correlation"] + value, lags = user_cor["value"], stress_cor["lags"] + assert len(value) == len(lags) == 11 + + direct = correlate([v * 4.0 for v in pxy], pxy, fft=False) + # input data differs due to i/o, error is expected 1e-5 + assert direct == approx(value, rel=1e-5) diff --git a/tests/test_geom_opt.py b/tests/test_geom_opt.py index 4b239f32..ec3f304e 100644 --- a/tests/test_geom_opt.py +++ b/tests/test_geom_opt.py @@ -66,7 +66,7 @@ def test_saving_struct(tmp_path): ) opt_struct = read(results_path) - assert opt_struct.get_potential_energy() < init_energy + assert opt_struct.info["mace_energy"] < init_energy def test_saving_traj(tmp_path): diff --git a/tests/test_geomopt_cli.py b/tests/test_geomopt_cli.py index 18702580..0634aa6d 100644 --- a/tests/test_geomopt_cli.py +++ b/tests/test_geomopt_cli.py @@ -104,7 +104,7 @@ def test_traj(tmp_path): ) assert result.exit_code == 0 atoms = read(traj_path) - assert "forces" in atoms.calc.results + assert "mace_mp_forces" in atoms.arrays def test_fully_opt(tmp_path): @@ -298,7 +298,7 @@ def test_restart(tmp_path): ) assert result.exit_code == 0 atoms = read(results_path) - intermediate_energy = atoms.get_potential_energy() + intermediate_energy = atoms.info["mace_mp_energy"] result = runner.invoke( app, @@ -320,7 +320,7 @@ def test_restart(tmp_path): ) assert result.exit_code == 0 atoms = read(results_path) - final_energy = atoms.get_potential_energy() + final_energy = atoms.info["mace_mp_energy"] assert final_energy < intermediate_energy diff --git a/tests/test_md_cli.py b/tests/test_md_cli.py index 6c826564..fb32cfe9 100644 --- a/tests/test_md_cli.py +++ b/tests/test_md_cli.py @@ -503,3 +503,57 @@ def test_final_name(tmp_path): assert traj_path.exists() assert stats_path.exists() assert final_path.exists() + + +def test_write_kwargs(tmp_path): + """Test passing write-kwargs.""" + struct_path = DATA_PATH / "NaCl.cif" + file_prefix = tmp_path / "md" + log_path = tmp_path / "md.log" + summary_path = tmp_path / "summary.yml" + final_path = tmp_path / "md-final.extxyz" + traj_path = tmp_path / "md-traj.extxyz" + write_kwargs = ( + "{'invalidate_calc': False, 'columns': ['symbols', 'positions', 'masses']}" + ) + + result = runner.invoke( + app, + [ + "md", + "--ensemble", + "npt", + "--struct", + struct_path, + "--file-prefix", + file_prefix, + "--steps", + 2, + "--write-kwargs", + write_kwargs, + "--traj-every", + 1, + "--log", + log_path, + "--summary", + summary_path, + ], + ) + + assert result.exit_code == 0 + assert final_path.exists() + assert traj_path.exists() + final_atoms = read(final_path) + traj = read(traj_path, index=":") + + # Check columns has been set + assert not final_atoms.has("momenta") + assert not traj[0].has("momenta") + + # Check calculated results have been saved + assert "energy" in final_atoms.calc.results + assert "energy" in traj[0].calc.results + + # Check labelled info has been set + assert "mace_mp_energy" in final_atoms.info + assert "mace_mp_energy" in traj[0].info diff --git a/tests/test_single_point.py b/tests/test_single_point.py index 5514dd5c..6446e496 100644 --- a/tests/test_single_point.py +++ b/tests/test_single_point.py @@ -252,6 +252,22 @@ def test_no_atoms_or_path(): ) +def test_invalidate_calc(): + """Test setting invalidate_calc via write_kwargs.""" + struct_path = DATA_PATH / "NaCl.cif" + single_point = SinglePoint( + struct_path=struct_path, + architecture="mace", + calc_kwargs={"model": MODEL_PATH}, + ) + + single_point.run(write_kwargs={"invalidate_calc": False}) + assert "energy" in single_point.struct.calc.results + + single_point.run(write_kwargs={"invalidate_calc": True}) + assert "energy" not in single_point.struct.calc.results + + test_mlips_data = [ ("m3gnet", "cpu", -26.729949951171875), ("chgnet", "cpu", -29.331436157226562), diff --git a/tests/test_singlepoint_cli.py b/tests/test_singlepoint_cli.py index bf01fee7..bf75e104 100644 --- a/tests/test_singlepoint_cli.py +++ b/tests/test_singlepoint_cli.py @@ -277,3 +277,27 @@ def test_invalid_config(): ) assert result.exit_code == 1 assert isinstance(result.exception, ValueError) + + +def test_write_kwargs(tmp_path): + """Test setting invalidate_calc and write_results via write_kwargs.""" + results_path = tmp_path / "NaCl-results.extxyz" + + result = runner.invoke( + app, + [ + "singlepoint", + "--struct", + DATA_PATH / "NaCl.cif", + "--write-kwargs", + "{'invalidate_calc': False}", + "--out", + results_path, + ], + ) + assert result.exit_code == 0 + atoms = read(results_path) + assert "mace_mp_energy" in atoms.info + assert "mace_mp_forces" in atoms.arrays + assert "energy" in atoms.calc.results + assert "forces" in atoms.calc.results diff --git a/tests/test_utils.py b/tests/test_utils.py index 74ef53ff..7dbe5a22 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,8 +1,22 @@ """Test utility functions.""" from pathlib import Path +from typing import get_args -from janus_core.helpers.utils import dict_paths_to_strs, dict_remove_hyphens +from ase import Atoms +from ase.io import read +import pytest + +from janus_core.helpers.janus_types import Properties +from janus_core.helpers.mlip_calculators import choose_calculator +from janus_core.helpers.utils import ( + dict_paths_to_strs, + dict_remove_hyphens, + output_structs, +) + +DATA_PATH = Path(__file__).parent / "data/NaCl.cif" +MODEL_PATH = Path(__file__).parent / "models/mace_mp_small.model" def test_dict_paths_to_strs(): @@ -46,3 +60,81 @@ def test_dict_remove_hyphens(): assert dictionary["key_2"]["key_4"] == 4 assert dictionary["key_2"]["key_5"] == 5.0 assert dictionary["key_2"]["key6"]["key_7"] == "value7" + + +@pytest.mark.parametrize("arch", ["mace_mp", "m3gnet", "chgnet"]) +@pytest.mark.parametrize("write_results", [True, False]) +@pytest.mark.parametrize("properties", [None, ["energy"], ["energy", "forces"]]) +@pytest.mark.parametrize("invalidate_calc", [True, False]) +@pytest.mark.parametrize( + "write_kwargs", [{}, {"write_results": False}, {"set_info": False}] +) +def test_output_structs( + arch, write_results, properties, invalidate_calc, write_kwargs, tmp_path +): + """Test output_structs copies/moves results to Atoms.info and writes files.""" + struct = read(DATA_PATH) + struct.calc = choose_calculator(architecture=arch) + + if not properties: + results_keys = set(get_args(Properties)) + else: + results_keys = set(properties) + label_keys = {f"{arch}_{key}" for key in results_keys} + + write_kwargs = {} + output_file = tmp_path / "output.extxyz" + if write_results: + write_kwargs["filename"] = output_file + + # Use calculator + struct.get_potential_energy() + struct.get_stress() + + # Check all expected keys are in results + assert results_keys <= struct.calc.results.keys() + + # Check results and MLIP-labelled keys are not in info or arrays + assert not results_keys & struct.info.keys() + assert not results_keys & struct.arrays.keys() + assert not label_keys & struct.info.keys() + assert not label_keys & struct.arrays.keys() + + output_structs( + struct, + write_results=write_results, + properties=properties, + invalidate_calc=invalidate_calc, + write_kwargs=write_kwargs, + ) + + # Check results keys depend on invalidate_calc + if invalidate_calc: + assert not results_keys & struct.calc.results.keys() + else: + assert results_keys <= struct.calc.results.keys() + + # Check labelled keys added to info and arrays + if "set_info" not in write_kwargs or write_kwargs["set_info"]: + assert label_keys <= struct.info.keys() | struct.arrays.keys() + assert struct.info["arch"] == arch + + # Check file written correctly if write_results + if write_results: + assert output_file.exists() + atoms = read(output_file) + assert isinstance(atoms, Atoms) + + # Check labelled info and arrays was written and can be read back in + if "set_info" not in write_kwargs or write_kwargs["set_info"]: + assert label_keys <= atoms.info.keys() | atoms.arrays.keys() + assert atoms.info["arch"] == arch + + # Check calculator results depend on invalidate_calc + if invalidate_calc: + assert atoms.calc is None + elif "write_results" not in write_kwargs or write_kwargs["write_results"]: + assert results_keys <= atoms.calc.results.keys() + + else: + assert not output_file.exists()