Skip to content

Commit

Permalink
fix: mpi comm (#896)
Browse files Browse the repository at this point in the history
* fix: forward mpi comm to Scaffold and Storage object

* fix: forward mpi comm to JobPool and Job

* test: add unittests
  • Loading branch information
drodarie authored Oct 29, 2024
1 parent 96af428 commit 0dac469
Show file tree
Hide file tree
Showing 9 changed files with 110 additions and 34 deletions.
3 changes: 1 addition & 2 deletions bsb/config/_attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
NoReferenceAttributeSignal,
RequirementError,
)
from ..services import MPI
from ._compile import _wrap_reserved
from ._hooks import run_hook
from ._make import (
Expand Down Expand Up @@ -396,7 +395,7 @@ def _boot_nodes(top_node, scaffold):
except Exception as e:
errr.wrap(BootError, e, prepend=f"Failed to boot {node}:")
# fixme: why is this here? Will deadlock in case of BootError on specific node only.
MPI.barrier()
scaffold._comm.barrier()


def _unset_nodes(top_node):
Expand Down
2 changes: 1 addition & 1 deletion bsb/config/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def _update_storage_node(self, storage):
if self.storage.engine != storage.format:
self.storage.engine = storage.format
if self.storage.root != storage.root:
self.storage.root = storage.root
self.storage._root = storage.root

def __str__(self):
return str(self.__tree__())
Expand Down
30 changes: 21 additions & 9 deletions bsb/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
from .placement import PlacementStrategy
from .profiling import meter
from .reporting import report
from .services import MPI, JobPool
from .services import JobPool
from .services._pool_listeners import NonTTYTerminalListener, TTYTerminalListener
from .services.mpi import MPIService
from .services.pool import Job, Workflow
from .simulation import get_simulation_adapter
from .storage import Storage, open_storage
Expand All @@ -39,15 +40,17 @@


@meter()
def from_storage(root):
def from_storage(root, comm=None):
"""
Load :class:`.core.Scaffold` from a storage object.
:param root: Root (usually path) pointing to the storage object.
:param mpi4py.MPI.Comm comm: MPI communicator that shares control
over the Storage.
:returns: A network scaffold
:rtype: :class:`Scaffold`
"""
return open_storage(root).load()
return open_storage(root, comm).load()


_cfg_props = (
Expand Down Expand Up @@ -128,6 +131,8 @@ def __init__(self, config=None, storage=None, clear=False, comm=None):
:type storage: :class:`~.storage.Storage`
:param clear: Start with a new network, clearing any previously stored information
:type clear: bool
:param comm: MPI communicator that shares control over the Storage.
:type comm: mpi4py.MPI.Comm
:returns: A network object
:rtype: :class:`~.core.Scaffold`
"""
Expand All @@ -137,7 +142,7 @@ def __init__(self, config=None, storage=None, clear=False, comm=None):
)
self._configuration = None
self._storage = None
self._comm = comm or MPI
self._comm = MPIService(comm)
self._bootstrap(config, storage, clear=clear)

def __contains__(self, component):
Expand All @@ -151,10 +156,10 @@ def __repr__(self):
return f"'{file}' with {cells_placed} cell types, and {n_types} connection_types"

def is_main_process(self) -> bool:
return not MPI.get_rank()
return not self._comm.get_rank()

def is_worker_process(self) -> bool:
return bool(MPI.get_rank())
return bool(self._comm.get_rank())

def _bootstrap(self, config, storage, clear=False):
if config is None:
Expand All @@ -174,7 +179,12 @@ def _bootstrap(self, config, storage, clear=False):
if not storage:
# No storage given, create one.
report("Creating storage from config.", level=4)
storage = Storage(config.storage.engine, config.storage.root)
storage = Storage(
config.storage.engine, config.storage.root, self._comm.get_communicator()
)
else:
# Override MPI comm of storage to match the scaffold's
storage._comm = self._comm
if clear:
# Storage given, but asked to clear it before use.
storage.remove()
Expand All @@ -183,14 +193,16 @@ def _bootstrap(self, config, storage, clear=False):
self._configuration = config
# Make sure the storage config node reflects the storage we are using
config._update_storage_node(storage)
# Give the scaffold access to the unitialized storage object (for use during
# Give the scaffold access to the uninitialized storage object (for use during
# config bootstrapping).
self._storage = storage
# First, the scaffold is passed to each config node, and their boot methods called
self._configuration._bootstrap(self)
# Then, `storage` is initted for the scaffold, and `config` is stored (happens
# Then, `storage` is initialized for the scaffold, and `config` is stored (happens
# inside the `storage` property).
self.storage = storage
# Synchronize the JobPool static variable so that each core uses the same ids.
JobPool._next_pool_id = self._comm.bcast(JobPool._next_pool_id, root=0)

storage_cfg = _config_property("storage")
for attr in _cfg_props:
Expand Down
5 changes: 2 additions & 3 deletions bsb/morphologies/selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from ..config import types
from ..config._attrs import cfglist
from ..exceptions import MissingMorphologyError, SelectorError
from ..services import MPI
from . import Morphology

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -92,11 +91,11 @@ def __boot__(self):
try:
morphos = self._scrape_nm(self.names)
except:
MPI.barrier()
self.scaffold._comm.barrier()
raise
for name, morpho in morphos.items():
self.scaffold.morphologies.save(name, morpho, overwrite=True)
MPI.barrier()
self.scaffold._comm.barrier()

@classmethod
def _swc_url(cls, archive, name):
Expand Down
9 changes: 7 additions & 2 deletions bsb/services/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,14 @@


class MPIService:
def __init__(self):
"""
Interface for an MPI communication context.
This class also emulates an MPI communication context in a single-node context.
"""

def __init__(self, comm=None):
self._mpi = MPIModule("mpi4py.MPI")
self._comm = self._mpi.COMM_WORLD
self._comm = comm or self._mpi.COMM_WORLD

def get_communicator(self):
return self._comm
Expand Down
15 changes: 8 additions & 7 deletions bsb/services/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ class name, ``_name`` for the job name and ``_c`` for the chunk. These are used
JobPoolError,
JobSchedulingError,
)
from . import MPI
from ._util import ErrorModule, MockModule

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -295,6 +294,7 @@ def __init__(
cache_items=None,
):
self.pool_id = pool.id
self._comm = pool._comm
self._args = args
self._kwargs = kwargs
self._deps = set(deps or [])
Expand Down Expand Up @@ -431,7 +431,7 @@ def _dep_completed(self, dep):
else:
# When all our dependencies have been discarded we can queue ourselves. Unless the
# pool is serial, then the pool itself just runs all jobs in order.
if not self._deps and MPI.get_size() > 1:
if not self._deps and self._comm.get_size() > 1:
# self._pool is set when the pool first tried to enqueue us, but we were still
# waiting for deps, in the `_enqueue` method below.
self._enqueue(self._pool)
Expand Down Expand Up @@ -544,6 +544,7 @@ def __init__(self, scaffold, fail_fast=False, workflow: "Workflow" = None):
self._schedulers: list[concurrent.futures.Future] = []
self.id: int = None
self._scaffold = scaffold
self._comm = scaffold._comm
self._unhandled_errors = []
self._running_futures: list[concurrent.futures.Future] = []
self._mpipool: typing.Optional["MPIExecutor"] = None
Expand All @@ -556,7 +557,7 @@ def __init__(self, scaffold, fail_fast=False, workflow: "Workflow" = None):
self._fail_fast = fail_fast
self._workflow = workflow
self._cache_buffer = np.zeros(1000, dtype=np.uint64)
self._cache_window = MPI.window(self._cache_buffer)
self._cache_window = self._comm.window(self._cache_buffer)

def __enter__(self):
self._context = ExitStack()
Expand Down Expand Up @@ -605,7 +606,7 @@ def jobs(self) -> list[Job]:

@property
def parallel(self):
return MPI.get_size() > 1
return self._comm.get_size() > 1

@classmethod
def get_owner(cls, id):
Expand All @@ -620,7 +621,7 @@ def owner(self):
return self.get_owner(self.id)

def is_main(self):
return MPI.get_rank() == 0
return self._comm.get_rank() == 0

def get_submissions_of(self, submitter):
return [job for job in self._job_queue if job.submitter is submitter]
Expand Down Expand Up @@ -746,7 +747,7 @@ def _execute_parallel(self):
# master logic.

# Check if we need to abort our process due to errors etc.
abort = MPI.bcast(None)
abort = self._comm.bcast(None)
if abort:
raise WorkflowError(
"Unhandled exceptions during parallel execution.",
Expand Down Expand Up @@ -816,7 +817,7 @@ def _execute_parallel(self):
# Shut down our internal pool
self._mpipool.shutdown(wait=False, cancel_futures=True)
# Broadcast whether the worker nodes should raise an unhandled error.
MPI.bcast(self._workers_raise_unhandled)
self._comm.bcast(self._workers_raise_unhandled)

def _execute_serial(self):
# Wait for jobs to finish scheduling
Expand Down
32 changes: 25 additions & 7 deletions bsb/storage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from .. import plugins
from ..exceptions import UnknownStorageEngineError
from ..services import MPI
from ..services.mpi import MPIService

if typing.TYPE_CHECKING:
from .interfaces import (
Expand Down Expand Up @@ -89,7 +89,16 @@ def get_engines():


def create_engine(name, root, comm):
# Create an engine from the engine's Engine interface.
"""
Create an engine from the engine's Engine interface.
:param str name: The name of the engine to create.
:param object root: An object that uniquely describes the storage, such as a filename
or path. The value to be provided depends on the engine. For the hdf5 engine
the filename has to be provided.
:param bsb.services.mpi.MPIService comm: MPI communicator that shares control over the
Engine interface.
"""
return get_engine_support(name)["Engine"](root, comm)


Expand Down Expand Up @@ -119,7 +128,7 @@ def __getattr__(self, attr):

class Storage:
"""
Factory class that produces all of the features and shims the functionality of the
Factory class that produces all the features and shims the functionality of the
underlying engine.
"""

Expand All @@ -143,7 +152,7 @@ def __init__(self, engine, root, comm=None, main=0, missing_ok=True):
:type comm: mpi4py.MPI.Comm
:param main: Rank of the MPI process that executes single-node tasks.
"""
self._comm = comm or MPI
self._comm = MPIService(comm)
self._engine = create_engine(engine, root, self._comm)
self._features = [
fname for fname, supported in view_support()[engine].items() if supported
Expand Down Expand Up @@ -244,7 +253,7 @@ def load(self):
"""
from ..core import Scaffold

return Scaffold(storage=self)
return Scaffold(storage=self, comm=self._comm.get_communicator())

def load_active_config(self):
"""
Expand Down Expand Up @@ -368,11 +377,20 @@ def get_chunk_stats(self):
return self._engine.get_chunk_stats()


def open_storage(root):
def open_storage(root, comm=None):
"""
Load a Storage object from its root.
:param root: Root (usually path) pointing to the storage object.
:param mpi4py.MPI.Comm comm: MPI communicator that shares control
over the Storage.
:returns: A storage object
:rtype: :class:`Storage`
"""
engines = get_engines()
for name, engine in engines.items():
if engine.peek_exists(root) and engine.recognizes(root):
return Storage(name, root, missing_ok=False)
return Storage(name, root, comm, missing_ok=False)
else:
for name, engine in engines.items():
if engine.peek_exists(root):
Expand Down
41 changes: 39 additions & 2 deletions tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
import unittest

from bsb_test import NetworkFixture, RandomStorageFixture
from bsb_test import NetworkFixture, RandomStorageFixture, skip_serial, timeout
from mpi4py import MPI

from bsb import Configuration, PlacementSet, core
from bsb import Configuration, PlacementSet, Scaffold, Storage, core, get_engine_node


class TestCore(
Expand Down Expand Up @@ -97,6 +99,41 @@ def test_diagrams(self):
self.assertIn('cell2[label="cell2 (3 cell2)"]', storage_diagram)
self.assertIn('cell1 -> cell2[label="a_to_b (9)"]', storage_diagram)

@skip_serial
@timeout(3)
def test_mpi_from_storage(self):
self.network.compile(clear=True)
world = MPI.COMM_WORLD
if world.Get_rank() != 1:
# we make rank 1 skip while the others would load the network
group = world.group.Excl([1])
comm = world.Create_group(group)
core.from_storage(self.network.storage.root, comm)

@skip_serial
@timeout(3)
def test_mpi_compile(self):
world = MPI.COMM_WORLD
if world.Get_rank() != 1:
# we make rank 1 skip while the others compile the network
group = world.group.Excl([1])
comm = world.Create_group(group)
# Test compile with no storage
Scaffold(
Configuration.default(
storage=dict(engine="hdf5", root="test_network.hdf5")
),
comm=comm,
).compile(clear=True)
if world.Get_rank() == 0:
os.remove("test_network.hdf5")
# Test compile with external storage
s = Storage("hdf5", get_engine_node("hdf5")(engine="hdf5").root, comm=comm)
# self.cfg was modified when creating self.network but should update to match
# the new storage
Scaffold(self.cfg, storage=s, comm=comm).compile(clear=True)
s.remove()


class TestProfiling(
RandomStorageFixture, NetworkFixture, unittest.TestCase, engine_name="hdf5"
Expand Down
Loading

0 comments on commit 0dac469

Please sign in to comment.