Skip to content

Commit

Permalink
fix: after hooks postprocessing for multi threads
Browse files Browse the repository at this point in the history
  • Loading branch information
drodarie committed Dec 11, 2024
1 parent 8b32630 commit 73989f5
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 8 deletions.
16 changes: 8 additions & 8 deletions bsb/postprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ class AfterPlacementHook(abc.ABC):
name: str = config.attr(key=True)

def queue(self, pool):
pool.queue(
lambda scaffold: scaffold.after_placement[self.name].postprocess(),
submitter=self,
)
def static_function(scaffold, name):
return scaffold.after_placement[name].postprocess()

pool.queue(static_function, (self.name,), submitter=self)

@abc.abstractmethod
def postprocess(self):
Expand All @@ -28,10 +28,10 @@ class AfterConnectivityHook(abc.ABC):
name: str = config.attr(key=True)

def queue(self, pool):
pool.queue(
lambda scaffold: scaffold.after_connectivity[self.name].postprocess(),
submitter=self,
)
def static_function(scaffold, name):
return scaffold.after_connectivity[name].postprocess()

pool.queue(static_function, (self.name,), submitter=self)

@abc.abstractmethod
def postprocess(self):
Expand Down
85 changes: 85 additions & 0 deletions tests/test_postprocessing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import os
import unittest

from bsb_test import RandomStorageFixture

from bsb import (
MPI,
AfterConnectivityHook,
AfterPlacementHook,
Configuration,
Scaffold,
config,
)


class TestAfterConnectivityHook(
RandomStorageFixture, unittest.TestCase, engine_name="hdf5"
):
def setUp(self):
super().setUp()

@config.node
class TestAfterConn(AfterConnectivityHook):
def postprocess(self):
with open(f"test_after_conn_{MPI.get_rank()}.txt", "a") as f:
# make sure we have access to the scaffold context
f.write(f"{self.scaffold.configuration.name}\n")

self.network = Scaffold(
config=Configuration.default(
name="Test config",
after_connectivity={"test_after_conn": TestAfterConn()},
),
storage=self.storage,
)

def test_after_connectivity_job(self):
self.network.compile()
count_files = 0
for filename in os.listdir():
if filename.startswith(f"test_after_conn_{MPI.get_rank()}"):
count_files += 1
with open(filename, "r") as f:
lines = f.readlines()
self.assertEqual(
len(lines), 1, "The postprocess should be called only once."
)
self.assertEqual(lines[0], "Test config\n")
os.remove(filename)
self.assertEqual(count_files, 1)


class TestAfterPlacementHook(RandomStorageFixture, unittest.TestCase, engine_name="hdf5"):
def setUp(self):
super().setUp()

@config.node
class TestAfterPlace(AfterPlacementHook):
def postprocess(self):
with open(f"test_after_place_{MPI.get_rank()}.txt", "a") as f:
# make sure we have access to the scaffold context
f.write(f"{self.scaffold.configuration.name}\n")

self.network = Scaffold(
config=Configuration.default(
name="Test config",
after_placement={"test_after_placement": TestAfterPlace()},
),
storage=self.storage,
)

def test_after_placement_job(self):
self.network.compile()
count_files = 0
for filename in os.listdir():
if filename.startswith(f"test_after_place_{MPI.get_rank()}"):
count_files += 1
with open(filename, "r") as f:
lines = f.readlines()
self.assertEqual(
len(lines), 1, "The postprocess should be called only once."
)
self.assertEqual(lines[0], "Test config\n")
os.remove(filename)
self.assertEqual(count_files, 1)

0 comments on commit 73989f5

Please sign in to comment.