From 0dac46937c26a078b8b92864321919d8ed2ccac7 Mon Sep 17 00:00:00 2001 From: Dimitri RODARIE Date: Tue, 29 Oct 2024 14:45:26 +0100 Subject: [PATCH] fix: mpi comm (#896) * fix: forward mpi comm to Scaffold and Storage object * fix: forward mpi comm to JobPool and Job * test: add unittests --- bsb/config/_attrs.py | 3 +-- bsb/config/_config.py | 2 +- bsb/core.py | 30 ++++++++++++++++++-------- bsb/morphologies/selector.py | 5 ++--- bsb/services/mpi.py | 9 ++++++-- bsb/services/pool.py | 15 +++++++------ bsb/storage/__init__.py | 32 ++++++++++++++++++++++------ tests/test_core.py | 41 ++++++++++++++++++++++++++++++++++-- tests/test_jobs.py | 7 +++++- 9 files changed, 110 insertions(+), 34 deletions(-) diff --git a/bsb/config/_attrs.py b/bsb/config/_attrs.py index 9aec9e56f..b86e2f0d7 100644 --- a/bsb/config/_attrs.py +++ b/bsb/config/_attrs.py @@ -14,7 +14,6 @@ NoReferenceAttributeSignal, RequirementError, ) -from ..services import MPI from ._compile import _wrap_reserved from ._hooks import run_hook from ._make import ( @@ -396,7 +395,7 @@ def _boot_nodes(top_node, scaffold): except Exception as e: errr.wrap(BootError, e, prepend=f"Failed to boot {node}:") # fixme: why is this here? Will deadlock in case of BootError on specific node only. - MPI.barrier() + scaffold._comm.barrier() def _unset_nodes(top_node): diff --git a/bsb/config/_config.py b/bsb/config/_config.py index 599b806e9..523af696f 100644 --- a/bsb/config/_config.py +++ b/bsb/config/_config.py @@ -182,7 +182,7 @@ def _update_storage_node(self, storage): if self.storage.engine != storage.format: self.storage.engine = storage.format if self.storage.root != storage.root: - self.storage.root = storage.root + self.storage._root = storage.root def __str__(self): return str(self.__tree__()) diff --git a/bsb/core.py b/bsb/core.py index 891f62ae4..05f7d643a 100644 --- a/bsb/core.py +++ b/bsb/core.py @@ -17,8 +17,9 @@ from .placement import PlacementStrategy from .profiling import meter from .reporting import report -from .services import MPI, JobPool +from .services import JobPool from .services._pool_listeners import NonTTYTerminalListener, TTYTerminalListener +from .services.mpi import MPIService from .services.pool import Job, Workflow from .simulation import get_simulation_adapter from .storage import Storage, open_storage @@ -39,15 +40,17 @@ @meter() -def from_storage(root): +def from_storage(root, comm=None): """ Load :class:`.core.Scaffold` from a storage object. :param root: Root (usually path) pointing to the storage object. + :param mpi4py.MPI.Comm comm: MPI communicator that shares control + over the Storage. :returns: A network scaffold :rtype: :class:`Scaffold` """ - return open_storage(root).load() + return open_storage(root, comm).load() _cfg_props = ( @@ -128,6 +131,8 @@ def __init__(self, config=None, storage=None, clear=False, comm=None): :type storage: :class:`~.storage.Storage` :param clear: Start with a new network, clearing any previously stored information :type clear: bool + :param comm: MPI communicator that shares control over the Storage. + :type comm: mpi4py.MPI.Comm :returns: A network object :rtype: :class:`~.core.Scaffold` """ @@ -137,7 +142,7 @@ def __init__(self, config=None, storage=None, clear=False, comm=None): ) self._configuration = None self._storage = None - self._comm = comm or MPI + self._comm = MPIService(comm) self._bootstrap(config, storage, clear=clear) def __contains__(self, component): @@ -151,10 +156,10 @@ def __repr__(self): return f"'{file}' with {cells_placed} cell types, and {n_types} connection_types" def is_main_process(self) -> bool: - return not MPI.get_rank() + return not self._comm.get_rank() def is_worker_process(self) -> bool: - return bool(MPI.get_rank()) + return bool(self._comm.get_rank()) def _bootstrap(self, config, storage, clear=False): if config is None: @@ -174,7 +179,12 @@ def _bootstrap(self, config, storage, clear=False): if not storage: # No storage given, create one. report("Creating storage from config.", level=4) - storage = Storage(config.storage.engine, config.storage.root) + storage = Storage( + config.storage.engine, config.storage.root, self._comm.get_communicator() + ) + else: + # Override MPI comm of storage to match the scaffold's + storage._comm = self._comm if clear: # Storage given, but asked to clear it before use. storage.remove() @@ -183,14 +193,16 @@ def _bootstrap(self, config, storage, clear=False): self._configuration = config # Make sure the storage config node reflects the storage we are using config._update_storage_node(storage) - # Give the scaffold access to the unitialized storage object (for use during + # Give the scaffold access to the uninitialized storage object (for use during # config bootstrapping). self._storage = storage # First, the scaffold is passed to each config node, and their boot methods called self._configuration._bootstrap(self) - # Then, `storage` is initted for the scaffold, and `config` is stored (happens + # Then, `storage` is initialized for the scaffold, and `config` is stored (happens # inside the `storage` property). self.storage = storage + # Synchronize the JobPool static variable so that each core use the same ids. + JobPool._next_pool_id = self._comm.bcast(JobPool._next_pool_id, root=0) storage_cfg = _config_property("storage") for attr in _cfg_props: diff --git a/bsb/morphologies/selector.py b/bsb/morphologies/selector.py index 8954246eb..ceae8636a 100644 --- a/bsb/morphologies/selector.py +++ b/bsb/morphologies/selector.py @@ -13,7 +13,6 @@ from ..config import types from ..config._attrs import cfglist from ..exceptions import MissingMorphologyError, SelectorError -from ..services import MPI from . import Morphology if typing.TYPE_CHECKING: @@ -92,11 +91,11 @@ def __boot__(self): try: morphos = self._scrape_nm(self.names) except: - MPI.barrier() + self.scaffold._comm.barrier() raise for name, morpho in morphos.items(): self.scaffold.morphologies.save(name, morpho, overwrite=True) - MPI.barrier() + self.scaffold._comm.barrier() @classmethod def _swc_url(cls, archive, name): diff --git a/bsb/services/mpi.py b/bsb/services/mpi.py index d74924b5a..879a4b994 100644 --- a/bsb/services/mpi.py +++ b/bsb/services/mpi.py @@ -6,9 +6,14 @@ class MPIService: - def __init__(self): + """ + Interface for MPI Communication context. + This class will also emulate MPI Communication context in single node context. + """ + + def __init__(self, comm=None): self._mpi = MPIModule("mpi4py.MPI") - self._comm = self._mpi.COMM_WORLD + self._comm = comm or self._mpi.COMM_WORLD def get_communicator(self): return self._comm diff --git a/bsb/services/pool.py b/bsb/services/pool.py index dafb5b363..5d0e8dff9 100644 --- a/bsb/services/pool.py +++ b/bsb/services/pool.py @@ -61,7 +61,6 @@ class name, ``_name`` for the job name and ``_c`` for the chunk. These are used JobPoolError, JobSchedulingError, ) -from . import MPI from ._util import ErrorModule, MockModule if typing.TYPE_CHECKING: @@ -295,6 +294,7 @@ def __init__( cache_items=None, ): self.pool_id = pool.id + self._comm = pool._comm self._args = args self._kwargs = kwargs self._deps = set(deps or []) @@ -431,7 +431,7 @@ def _dep_completed(self, dep): else: # When all our dependencies have been discarded we can queue ourselves. Unless the # pool is serial, then the pool itself just runs all jobs in order. - if not self._deps and MPI.get_size() > 1: + if not self._deps and self._comm.get_size() > 1: # self._pool is set when the pool first tried to enqueue us, but we were still # waiting for deps, in the `_enqueue` method below. self._enqueue(self._pool) @@ -544,6 +544,7 @@ def __init__(self, scaffold, fail_fast=False, workflow: "Workflow" = None): self._schedulers: list[concurrent.futures.Future] = [] self.id: int = None self._scaffold = scaffold + self._comm = scaffold._comm self._unhandled_errors = [] self._running_futures: list[concurrent.futures.Future] = [] self._mpipool: typing.Optional["MPIExecutor"] = None @@ -556,7 +557,7 @@ def __init__(self, scaffold, fail_fast=False, workflow: "Workflow" = None): self._fail_fast = fail_fast self._workflow = workflow self._cache_buffer = np.zeros(1000, dtype=np.uint64) - self._cache_window = MPI.window(self._cache_buffer) + self._cache_window = self._comm.window(self._cache_buffer) def __enter__(self): self._context = ExitStack() @@ -605,7 +606,7 @@ def jobs(self) -> list[Job]: @property def parallel(self): - return MPI.get_size() > 1 + return self._comm.get_size() > 1 @classmethod def get_owner(cls, id): @@ -620,7 +621,7 @@ def owner(self): return self.get_owner(self.id) def is_main(self): - return MPI.get_rank() == 0 + return self._comm.get_rank() == 0 def get_submissions_of(self, submitter): return [job for job in self._job_queue if job.submitter is submitter] @@ -746,7 +747,7 @@ def _execute_parallel(self): # master logic. # Check if we need to abort our process due to errors etc. - abort = MPI.bcast(None) + abort = self._comm.bcast(None) if abort: raise WorkflowError( "Unhandled exceptions during parallel execution.", @@ -816,7 +817,7 @@ def _execute_parallel(self): # Shut down our internal pool self._mpipool.shutdown(wait=False, cancel_futures=True) # Broadcast whether the worker nodes should raise an unhandled error. - MPI.bcast(self._workers_raise_unhandled) + self._comm.bcast(self._workers_raise_unhandled) def _execute_serial(self): # Wait for jobs to finish scheduling diff --git a/bsb/storage/__init__.py b/bsb/storage/__init__.py index cced84b67..a3cfe6704 100644 --- a/bsb/storage/__init__.py +++ b/bsb/storage/__init__.py @@ -19,7 +19,7 @@ from .. import plugins from ..exceptions import UnknownStorageEngineError -from ..services import MPI +from ..services.mpi import MPIService if typing.TYPE_CHECKING: from .interfaces import ( @@ -89,7 +89,16 @@ def get_engines(): def create_engine(name, root, comm): - # Create an engine from the engine's Engine interface. + """ + Create an engine from the engine's Engine interface. + + :param str name: The name of the engine to create. + :param object root: An object that uniquely describes the storage, such as a filename + or path. The value to be provided depends on the engine. For the hdf5 engine + the filename has to be provided. + :param bsb.services.mpi.MPIService comm: MPI communicator that shares control over the + Engine interface. + """ return get_engine_support(name)["Engine"](root, comm) @@ -119,7 +128,7 @@ def __getattr__(self, attr): class Storage: """ - Factory class that produces all of the features and shims the functionality of the + Factory class that produces all the features and shims the functionality of the underlying engine. """ @@ -143,7 +152,7 @@ def __init__(self, engine, root, comm=None, main=0, missing_ok=True): :type comm: mpi4py.MPI.Comm :param main: Rank of the MPI process that executes single-node tasks. """ - self._comm = comm or MPI + self._comm = MPIService(comm) self._engine = create_engine(engine, root, self._comm) self._features = [ fname for fname, supported in view_support()[engine].items() if supported @@ -244,7 +253,7 @@ def load(self): """ from ..core import Scaffold - return Scaffold(storage=self) + return Scaffold(storage=self, comm=self._comm.get_communicator()) def load_active_config(self): """ @@ -368,11 +377,20 @@ def get_chunk_stats(self): return self._engine.get_chunk_stats() -def open_storage(root): +def open_storage(root, comm=None): + """ + Load a Storage object from its root. + + :param root: Root (usually path) pointing to the storage object. + :param mpi4py.MPI.Comm comm: MPI communicator that shares control + over the Storage. + :returns: A network scaffold + :rtype: :class:`Storage` + """ engines = get_engines() for name, engine in engines.items(): if engine.peek_exists(root) and engine.recognizes(root): - return Storage(name, root, missing_ok=False) + return Storage(name, root, comm, missing_ok=False) else: for name, engine in engines.items(): if engine.peek_exists(root): diff --git a/tests/test_core.py b/tests/test_core.py index 11b724832..4744907f2 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,8 +1,10 @@ +import os import unittest -from bsb_test import NetworkFixture, RandomStorageFixture +from bsb_test import NetworkFixture, RandomStorageFixture, skip_serial, timeout +from mpi4py import MPI -from bsb import Configuration, PlacementSet, core +from bsb import Configuration, PlacementSet, Scaffold, Storage, core, get_engine_node class TestCore( @@ -97,6 +99,41 @@ def test_diagrams(self): self.assertIn('cell2[label="cell2 (3 cell2)"]', storage_diagram) self.assertIn('cell1 -> cell2[label="a_to_b (9)"]', storage_diagram) + @skip_serial + @timeout(3) + def test_mpi_from_storage(self): + self.network.compile(clear=True) + world = MPI.COMM_WORLD + if world.Get_rank() != 1: + # we make rank 1 skip while the others would load the network + group = world.group.Excl([1]) + comm = world.Create_group(group) + core.from_storage(self.network.storage.root, comm) + + @skip_serial + @timeout(3) + def test_mpi_compile(self): + world = MPI.COMM_WORLD + if world.Get_rank() != 1: + # we make rank 1 skip while the others would load the network + group = world.group.Excl([1]) + comm = world.Create_group(group) + # Test compile with no storage + Scaffold( + Configuration.default( + storage=dict(engine="hdf5", root="test_network.hdf5") + ), + comm=comm, + ).compile(clear=True) + if world.Get_rank() == 0: + os.remove("test_network.hdf5") + # Test compile with external storage + s = Storage("hdf5", get_engine_node("hdf5")(engine="hdf5").root, comm=comm) + # self.cfg was modified when creating self.network but should update to match + # the new storage + Scaffold(self.cfg, storage=s, comm=comm).compile(clear=True) + s.remove() + class TestProfiling( RandomStorageFixture, NetworkFixture, unittest.TestCase, engine_name="hdf5" diff --git a/tests/test_jobs.py b/tests/test_jobs.py index f5c14e032..4472dd1be 100644 --- a/tests/test_jobs.py +++ b/tests/test_jobs.py @@ -10,6 +10,7 @@ NumpyTestCase, RandomStorageFixture, skip_parallel, + skip_serial, timeout, ) @@ -251,7 +252,7 @@ def test_placement_job(self): self.assertClose([[0, 0, 0]], ps.load_positions()) -@unittest.skipIf(MPI.get_size() < 2, "Skipped during serial testing.") +@skip_serial class TestParallelScheduler( RandomStorageFixture, NetworkFixture, unittest.TestCase, engine_name="hdf5" ): @@ -554,6 +555,7 @@ def place(self, chunk, indicators): self.network.placement.withcache.cache_something.cache_clear() self.id_cache = _cache_hash("{root}.placement.withcache.cache_something") + @timeout(3) def test_cache_registration(self): """Test that when a cache is hit, it is registered in the scaffold""" self.network.placement.withcache.place(None, None) @@ -562,6 +564,7 @@ def test_cache_registration(self): [*self.network._pool_cache.keys()], ) + @timeout(3) def test_method_detection(self): """Test that we can detect which jobs need which items""" self.assertEqual( @@ -569,6 +572,7 @@ def test_method_detection(self): get_node_cache_items(self.network.placement.withcache), ) + @timeout(3) def test_pool_required_cache(self): """Test that the pool knows which cache items are required""" with self.network.create_job_pool() as pool: @@ -589,6 +593,7 @@ def test_pool_required_cache(self): "bsb.services.pool.JobPool._read_required_cache_items", lambda self: mock_read_required_cache_items(self), ) + @timeout(3) def test_cache_survival(self): """Test that the required cache items survive until the jobs are done."""