Skip to content

Commit

Permalink
fix: update morphology introduce_point function (#884)
Browse files Browse the repository at this point in the history
* fix: update morphology introduce_point function

* fix: assert index range for introduce_point

* fix: apply suggestions from code review

Co-authored-by: Robin De Schepper <[email protected]>

* fix: split tests and add checks for properties

---------

Co-authored-by: Robin De Schepper <[email protected]>
  • Loading branch information
drodarie and Helveg authored Sep 13, 2024
1 parent 88a5996 commit 43ef230
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 13 deletions.
44 changes: 31 additions & 13 deletions bsb/morphologies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1489,28 +1489,46 @@ def get_label_mask(self, labels):
"""
return self.labels.get_mask(labels)

def introduce_point(self, index, *args, labels=None):
def introduce_point(self, index, position, radius=None, labels=None, properties=None):
"""
Insert a new point at ``index``, before the existing point at ``index``.
Radius, labels and extra properties can be set or will be copied from the
existing point at ``index``.
:param index: Index of the new point.
:type index: int
:param args: Vector coordinates of the new point
:type args: float
:param position: Coordinates of the new point
:type position: List[float]
:param radius: The radius to assign to the point.
:type radius: float
:param labels: The labels to assign to the point.
:type labels: list
:param properties: The properties to assign to the point.
:type properties: dict
"""
if index < 0 or index >= len(self.points):
raise IndexError(
f"Could not introduce point in branch at index {index}: out of bounds for branch length {len(self)}."
)
self._on_mutate()
for v, vector_name in enumerate(type(self).vectors):
vector = getattr(self, vector_name)
new_vector = np.concatenate((vector[:index], [args[v]], vector[index:]))
setattr(self, vector_name, new_vector)
if labels is None:
labels = set()
for label, mask in self._label_masks.items():
has_label = label in labels
new_mask = np.concatenate((mask[:index], [has_label], mask[index:]))
self._label_masks[label] = new_mask
old_labels = self.labels[index]
self.points = np.insert(self.points, index, position, 0)
self._labels = np.insert(self._labels, index, old_labels)
self._radii = np.insert(self._radii, index, radius or self._radii[index])
# By default, duplicate the existing property value ...
for k, v in self._properties.items():
self._properties[k] = np.insert(v, index, v[index])
if labels is not None:
self.label(labels, [index])
# ... and overwrite it with any new property values, if given.
if properties is not None:
for k, v in properties.items():
if k in self._properties:
self._properties[k][index] = v
else:
raise MorphologyError(
f"Property key '{k}' is not part of the Branch."
)

def introduce_arc_point(self, arc_val):
"""
Expand Down
54 changes: 54 additions & 0 deletions tests/test_morphologies.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,60 @@ def test_mlabel(self):
m.label(["B", "A"], [0, 1, 2])
self.assertEqual(3, np.sum(m.get_label_mask(["A"])[:3]), "then first 3 lbled")

def test_introduce_point_defaults(self):
x = 0

def _on_mutate():
nonlocal x
x += 1

b = Branch(
np.arange(12).reshape(4, 3),
np.arange(4),
properties={"tags": np.arange(4, 8)},
)
b._on_mutate = _on_mutate
b.label(["A"], [0, 1])
b.label(["B"], [2, 3])
b.introduce_point(1, [12, 13, 14])
self.assertAll(b.points[1] == np.array([12, 13, 14]))
self.assertAll(b.radii == np.array([0, 1, 1, 2, 3]))
self.assertAll(b.labels == np.array([1, 1, 1, 2, 2]))
self.assertAll(b.tags == np.array([4, 5, 5, 6, 7]))
self.assertTrue(x == 1)

def test_introduce_point(self):
b = Branch(
np.arange(12).reshape(4, 3),
np.arange(4),
properties={"tags": np.arange(4, 8), "other": np.arange(8, 12)},
)
b.label(["A"], [0, 1])
b.label(["B"], [2, 3])
b.introduce_point(3, [15, 16, 17], 4, ["E"], {"tags": 8})
self.assertAll(b.points[-2] == np.array([15, 16, 17]))
self.assertAll(b.radii == np.array([0, 1, 2, 4, 3]))
self.assertAll(b.labels == np.array([1, 1, 2, 3, 2]))
# tag of the new point set in properties
self.assertAll(b.tags == np.array([4, 5, 6, 8, 7]))
# other of the new point copied in from the following point
self.assertAll(b.other == np.array([8, 9, 10, 11, 11]))

def test_introduce_point_errors(self):
b = Branch(
np.arange(12).reshape(4, 3),
np.arange(4),
properties={"tags": np.arange(4, 8)},
)
b.label(["A"], [0, 1])
b.label(["B"], [2, 3])
with self.assertRaises(IndexError):
b.introduce_point(4, [15, 16, 17])
with self.assertRaises(IndexError):
b.introduce_point(-1, [15, 16, 17])
with self.assertRaises(MorphologyError):
b.introduce_point(0, [15, 16, 17], properties={"wrong_key": 7})


class TestPointSetters(NumpyTestCase, unittest.TestCase):
def setUp(self):
Expand Down

0 comments on commit 43ef230

Please sign in to comment.