From e4491a78c57641c2dd0ec08ffdcc43153b6c9ce7 Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Mon, 30 Jan 2023 11:10:05 +0100 Subject: [PATCH] Add an `AnalysisCollection` class --- package/CHANGELOG | 5 +- package/MDAnalysis/analysis/base.py | 183 +++++++++++++++--- .../MDAnalysisTests/analysis/test_base.py | 74 ++++++- 3 files changed, 230 insertions(+), 32 deletions(-) diff --git a/package/CHANGELOG b/package/CHANGELOG index babb1f6064a..2c975780c72 100644 --- a/package/CHANGELOG +++ b/package/CHANGELOG @@ -13,7 +13,8 @@ The rules for this file: * release numbers follow "Semantic Versioning" http://semver.org ------------------------------------------------------------------------------ -??/??/?? IAlibay, pgbarletta, mglagolev, hmacdope, manuel.nuno.melo +??/??/?? IAlibay, pgbarletta, mglagolev, hmacdope, manuel.nuno.melo, + PicoCentauri * 2.5.0 @@ -24,6 +25,8 @@ Fixes (Issue #3336) Enhancements + * Add an `AnalysisCollection` class to perform multiple analyses on the same + trajectory (#3569, PR #4017). * Add pickling support for Atom, Residue, Segment, ResidueGroup and SegmentGroup. (PR #3953) diff --git a/package/MDAnalysis/analysis/base.py b/package/MDAnalysis/analysis/base.py index 266ae68a02e..a0292deb526 100644 --- a/package/MDAnalysis/analysis/base.py +++ b/package/MDAnalysis/analysis/base.py @@ -220,7 +220,152 @@ def __setstate__(self, state): self.data = state -class AnalysisBase(object): +class AnalysisCollection(object): + """ + Class for running a collection of analysis classes on a single trajectory. + + Running a collection of analyses with ``AnalysisCollection`` can result in + a speedup compared to running the individual objects since here we only + perform the trajectory looping once. + + The class assumes that each analysis is a child of + :class:`MDAnalysis.analysis.base.AnalysisBase`. Additionally, + the trajectory of all ``analysis_objects`` must be the same. 
+ + By default it is ensured that all analysis objects use the + *same original* timestep and not an altered one from a previous analysis + object. This behaviour can be changed with the ``reset_timestep`` parameter + of the :meth:`MDAnalysis.analysis.base.AnalysisCollection.run` method. + + Parameters + ---------- + *analysis_objects : tuple + Analysis objects to run on the same trajectory. + + Raises + ------ + ValueError + If the provided ``analysis_objects`` do not have the same trajectory. + AttributeError + If an ``analysis_object`` is not a child of + :class:`MDAnalysis.analysis.base.AnalysisBase`. + + Example + ------- + .. code-block:: python + + import MDAnalysis as mda + from MDAnalysis.analysis.rdf import InterRDF + from MDAnalysis.analysis.base import AnalysisCollection + from MDAnalysisTests.datafiles import TPR, XTC + + u = mda.Universe(TPR, XTC) + + # Select atoms + O = u.select_atoms('name O') + H = u.select_atoms('name H') + + # Create individual analysis objects + rdf_OO = InterRDF(O, O) + rdf_OH = InterRDF(O, H) + + # Create collection for common trajectory + collection = AnalysisCollection(rdf_OO, rdf_OH) + + # Run the collected analysis + collection.run(start=0, stop=100, step=10) + + # Results are stored in the individual objects + print(rdf_OO.results) + print(rdf_OH.results) + + .. 
versionadded:: 2.5.0 + """ + def __init__(self, *analysis_objects): + for analysis_object in analysis_objects: + if analysis_objects[0]._trajectory != analysis_object._trajectory: + raise ValueError("`analysis_objects` do not have the same " + "trajectory.") + if not isinstance(analysis_object, AnalysisBase): + raise AttributeError(f"Analysis object {analysis_object} is " + "not a child of `AnalysisBase`.") + + self._analysis_objects = analysis_objects + + def run(self, start=None, stop=None, step=None, frames=None, + verbose=None, reset_timestep=True): + """Perform the calculation + + Parameters + ---------- + start : int, optional + start frame of analysis + stop : int, optional + stop frame of analysis + step : int, optional + number of frames to skip between each analysed frame + frames : array_like, optional + array of integers or booleans to slice trajectory; ``frames`` can + only be used *instead* of ``start``, ``stop``, and ``step``. Setting + *both* ``frames`` and at least one of ``start``, ``stop``, ``step`` to a + non-default value will raise a :exc:`ValueError`. + verbose : bool, optional + Turn on verbosity + reset_timestep : bool, optional + Reset the timestep object for each ``analysis_object``. + Setting this to ``False`` can be useful if an ``analysis_object`` + is performing a trajectory manipulation which is also useful for the + subsequent ``analysis_objects`` e.g. unwrapping of molecules. 
+ """ + + # Ensure compatibility with API of version 0.15.0 + if not hasattr(self, "_analysis_objects"): + self._analysis_objects = (self, ) + + logger.info("Choosing frames to analyze") + # if verbose unchanged, use class default + verbose = getattr(self, '_verbose', + False) if verbose is None else verbose + + logger.info("Starting preparation") + + for analysis_object in self._analysis_objects: + analysis_object._setup_frames(analysis_object._trajectory, + start=start, + stop=stop, + step=step, + frames=frames) + analysis_object._prepare() + + logger.info("Starting analysis loop over" + f"{self._analysis_objects[0].n_frames} trajectory frames") + + for i, ts in enumerate(ProgressBar( + self._analysis_objects[0]._sliced_trajectory, + verbose=verbose)): + + if reset_timestep: + ts_original = ts.copy() + + for analysis_object in self._analysis_objects: + analysis_object._frame_index = i + analysis_object._ts = ts + analysis_object.frames[i] = ts.frame + analysis_object.times[i] = ts.time + analysis_object._single_frame() + + if reset_timestep: + ts = ts_original + + logger.info("Finishing up") + + for analysis_object in self._analysis_objects: + analysis_object._conclude() + + return self + + +class AnalysisBase(AnalysisCollection): r"""Base class for defining multi-frame analysis The class is designed as a template for creating multi-frame analyses. @@ -316,6 +461,7 @@ def __init__(self, trajectory, verbose=False, **kwargs): self._trajectory = trajectory self._verbose = verbose self.results = Results() + super(AnalysisBase, self).__init__(self) def _setup_frames(self, trajectory, start=None, stop=None, step=None, frames=None): @@ -402,10 +548,10 @@ def run(self, start=None, stop=None, step=None, frames=None, step : int, optional number of frames to skip between each analysed frame frames : array_like, optional - array of integers or booleans to slice trajectory; `frames` can - only be used *instead* of `start`, `stop`, and `step`. 
Setting - *both* `frames` and at least one of `start`, `stop`, `step` to a - non-default value will raise a :exc:`ValueError`. + array of integers or booleans to slice trajectory; ``frames`` can + only be used *instead* of ``start``, ``stop``, and ``step``. Setting + *both* ``frames`` and at least one of ``start``, ``stop``, ``step`` to a + non-default value will raise a :exc:`ValueError`. .. versionadded:: 2.2.0 @@ -418,28 +564,11 @@ def run(self, start=None, stop=None, step=None, frames=None, frame indices in the `frames` keyword argument. """ - logger.info("Choosing frames to analyze") - # if verbose unchanged, use class default - verbose = getattr(self, '_verbose', - False) if verbose is None else verbose - - self._setup_frames(self._trajectory, start=start, stop=stop, - step=step, frames=frames) - logger.info("Starting preparation") - self._prepare() - logger.info("Starting analysis loop over %d trajectory frames", - self.n_frames) - for i, ts in enumerate(ProgressBar( - self._sliced_trajectory, - verbose=verbose)): - self._frame_index = i - self._ts = ts - self.frames[i] = ts.frame - self.times[i] = ts.time - self._single_frame() - logger.info("Finishing up") - self._conclude() - return self + return super(AnalysisBase, self).run(start=start, + stop=stop, + step=step, + frames=frames, + verbose=verbose) class AnalysisFromFunction(AnalysisBase): diff --git a/testsuite/MDAnalysisTests/analysis/test_base.py b/testsuite/MDAnalysisTests/analysis/test_base.py index 09b7b158f0a..acb8ac33f25 100644 --- a/testsuite/MDAnalysisTests/analysis/test_base.py +++ b/testsuite/MDAnalysisTests/analysis/test_base.py @@ -27,10 +27,11 @@ import numpy as np -from numpy.testing import assert_equal, assert_almost_equal +from numpy.testing import assert_equal, assert_allclose import MDAnalysis as mda from MDAnalysis.analysis import base +from MDAnalysis.analysis.rdf import InterRDF from MDAnalysisTests.datafiles import PSF, DCD, TPR, XTC from MDAnalysisTests.util import 
no_deprecated_call @@ -147,6 +148,71 @@ def test_different_instances(self, results): assert new_results.data is not results.data +class TestAnalysisCollection: + @pytest.fixture + def universe(self): + return mda.Universe(TPR, XTC) + + def test_run(self, universe): + O = universe.select_atoms('name O') + H = universe.select_atoms('name H') + + rdf_OO = InterRDF(O, O) + rdf_OH = InterRDF(O, H) + + collection = base.AnalysisCollection(rdf_OO, rdf_OH) + collection.run(start=0, stop=100, step=10) + + assert rdf_OO.results is not None + assert rdf_OH.results is not None + + @pytest.mark.parametrize("reset_timestep", [True, False]) + def test_trajectory_manipulation(self, universe, reset_timestep): + + class CustomAnalysis(base.AnalysisBase): + """Custom class that is shifting positions in every step by 10.""" + def __init__(self, trajectory): + self._trajectory = trajectory + + def _prepare(self): + pass + + def _single_frame(self): + self._ts.positions += 10 + self.ref_pos = self._ts.positions.copy()[0,0] + + ana_1 = CustomAnalysis(universe.trajectory) + ana_2 = CustomAnalysis(universe.trajectory) + + collection = base.AnalysisCollection(ana_1, ana_2) + + collection.run(frames=[0], reset_timestep=reset_timestep) + + if reset_timestep: + assert ana_2.ref_pos == ana_1.ref_pos + else: + assert_allclose(ana_2.ref_pos, ana_1.ref_pos + 10.) 
+ + def test_no_trajectory_manipulation(self): + pass + + def test_inconsistent_trajectory(self, universe): + v = mda.Universe(TPR, XTC) + + with pytest.raises(ValueError, match="`analysis_objects` do not have the same"): + base.AnalysisCollection(InterRDF(universe.atoms, universe.atoms), + InterRDF(v.atoms, v.atoms)) + + def test_no_base_child(self, universe): + class CustomAnalysis: + def __init__(self, trajectory): + self._trajectory = trajectory + + # Create collection for common trajectory loop with inconsistent trajectory + with pytest.raises(AttributeError, match="not a child of `AnalysisBa"): + base.AnalysisCollection(CustomAnalysis(universe.trajectory)) + + class FrameAnalysis(base.AnalysisBase): """Just grabs frame numbers of frames it goes over""" @@ -194,7 +260,7 @@ def test_start_stop_step(u, run_kwargs, frames): assert an.n_frames == len(frames) assert_equal(an.found_frames, frames) assert_equal(an.frames, frames, err_msg=FRAMES_ERR) - assert_almost_equal(an.times, frames+1, decimal=4, err_msg=TIMES_ERR) + assert_allclose(an.times, frames+1, rtol=1e-4, err_msg=TIMES_ERR) @pytest.mark.parametrize('run_kwargs, frames', [ @@ -251,7 +317,7 @@ def test_frames_times(): assert an.n_frames == len(frames) assert_equal(an.found_frames, frames) assert_equal(an.frames, frames, err_msg=FRAMES_ERR) - assert_almost_equal(an.times, frames*100, decimal=4, err_msg=TIMES_ERR) + assert_allclose(an.times, frames*100, rtol=1e-4, err_msg=TIMES_ERR) def test_verbose(u): @@ -366,7 +432,7 @@ def test_AnalysisFromFunction_args_content(u): ans = base.AnalysisFromFunction(mass_xyz, protein, another, masses) assert len(ans.args) == 3 result = np.sum(ans.run().results.timeseries) - assert_almost_equal(result, -317054.67757345125, decimal=6) + assert_allclose(result, -317054.67757345125, rtol=1e-6) assert (ans.args[0] is protein) and (ans.args[1] is another) assert ans._trajectory is protein.universe.trajectory