Skip to content

Commit

Permalink
feat: add shape attribute to types.ndarray
Browse files Browse the repository at this point in the history
Add documentation for cylinder targetting. fix minor typos in documentations
  • Loading branch information
drodarie committed Oct 3, 2024
1 parent 2a271b1 commit cabeb50
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 22 deletions.
10 changes: 7 additions & 3 deletions bsb/config/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,13 +789,17 @@ class ndarray(TypeHandler):
:rtype: Callable
"""

def __init__(self, dtype=None):
def __init__(self, shape=None, dtype=None):
self.shape = shape
self.dtype = dtype

def __call__(self, value):
result = np.array(value, copy=False)
if self.dtype is not None:
return np.array(value, copy=False, dtype=self.dtype)
return np.array(value, copy=False)
result = np.asarray(result, dtype=self.dtype)
if self.shape is not None:
result = result.reshape(self.shape)
return result

@property
def __name__(self):
Expand Down
32 changes: 16 additions & 16 deletions bsb/connectivity/geometric/geometric_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,15 +577,15 @@ class Ellipsoid(GeometricShape, classmap_entry="ellipsoid"):
"""

origin = config.attr(
type=types.ndarray(dtype=float), required=True, hint=[0.0, 0.0, 0.0]
type=types.ndarray(shape=(3,), dtype=float), required=True, hint=[0.0, 0.0, 0.0]
)
"""The coordinates of the center of the ellipsoid."""
lambdas = config.attr(
type=types.ndarray(dtype=float), required=True, hint=[1.0, 0.5, 2.0]
type=types.ndarray(shape=(3,), dtype=float), required=True, hint=[1.0, 0.5, 2.0]
)
"""The length of the three semi-axes."""

@config.property(type=types.ndarray(), required=True)
@config.property(type=types.ndarray(shape=(3,)), required=True)
def v0(self):
"""The versor on which the first semi-axis lies."""
return self._v0
Expand All @@ -594,7 +594,7 @@ def v0(self):
def v0(self, value):
self._v0 = np.copy(value) / np.linalg.norm(value)

@config.property(type=types.ndarray(), required=True)
@config.property(type=types.ndarray(shape=(3,)), required=True)
def v1(self):
"""The versor on which the second semi-axis lies."""
return self._v1
Expand All @@ -603,7 +603,7 @@ def v1(self):
def v1(self, value):
self._v1 = np.copy(value) / np.linalg.norm(value)

@config.property(type=types.ndarray(), required=True)
@config.property(type=types.ndarray(shape=(3,)), required=True)
def v2(self):
"""The versor on which the third semi-axis lies."""
return self._v2
Expand Down Expand Up @@ -700,11 +700,11 @@ class Cone(GeometricShape, classmap_entry="cone"):
"""

apex = config.attr(
type=types.ndarray(dtype=float), required=True, hint=[0.0, 1.0, 0.0]
type=types.ndarray(shape=(3,), dtype=float), required=True, hint=[0.0, 1.0, 0.0]
)
"""The coordinates of the apex of the cone."""
origin = config.attr(
type=types.ndarray(dtype=float), required=True, hint=[0.0, 0.0, 0.0]
type=types.ndarray(shape=(3,), dtype=float), required=True, hint=[0.0, 0.0, 0.0]
)
"""The coordinates of the center of the cone's base."""
radius = config.attr(type=float, required=False, default=1.0)
Expand Down Expand Up @@ -824,11 +824,11 @@ class Cylinder(GeometricShape, classmap_entry="cylinder"):
"""

origin = config.attr(
type=types.ndarray(dtype=float), required=True, hint=[0.0, 0.0, 0.0]
type=types.ndarray(shape=(3,), dtype=float), required=True, hint=[0.0, 0.0, 0.0]
)
"""The coordinates of the center of the bottom circle of the cylinder."""
top_center = config.attr(
type=types.ndarray(dtype=float), required=True, hint=[0.0, 2.0, 0.0]
type=types.ndarray(shape=(3,), dtype=float), required=True, hint=[0.0, 2.0, 0.0]
)
"""The coordinates of the center of the top circle of the cylinder."""
radius = config.attr(type=float, required=False, default=1.0)
Expand Down Expand Up @@ -936,7 +936,7 @@ class Sphere(GeometricShape, classmap_entry="sphere"):
"""

origin = config.attr(
type=types.ndarray(dtype=float), required=True, hint=[0.0, 0.0, 0.0]
type=types.ndarray(shape=(3,), dtype=float), required=True, hint=[0.0, 0.0, 0.0]
)
"""The coordinates of the center of the sphere."""
radius = config.attr(type=float, required=False, default=1.0)
Expand Down Expand Up @@ -1008,11 +1008,11 @@ class Cuboid(GeometricShape, classmap_entry="cuboid"):
"""

origin = config.attr(
type=types.ndarray(dtype=float), required=True, hint=[0.0, 0.0, 0.0]
type=types.ndarray(shape=(3,), dtype=float), required=True, hint=[0.0, 0.0, 0.0]
)
"""The coordinates of the center of the barycenter of the bottom rectangle."""
top_center = config.attr(
type=types.ndarray(dtype=float), required=True, hint=[0.0, 1.0, 0.0]
type=types.ndarray(shape=(3,), dtype=float), required=True, hint=[0.0, 1.0, 0.0]
)
"""The coordinates of the center of the barycenter of the top rectangle."""
side_length_1 = config.attr(type=float, required=False, default=1.0)
Expand Down Expand Up @@ -1159,21 +1159,21 @@ class Parallelepiped(GeometricShape, classmap_entry="parallelepiped"):
"""

origin = config.attr(
type=types.ndarray(dtype=float), required=True, hint=[0.0, 0.0, 0.0]
type=types.ndarray(shape=(3,), dtype=float), required=True, hint=[0.0, 0.0, 0.0]
)
"""The coordinates of the left-bottom edge."""
side_vector_1 = config.attr(
type=types.ndarray(dtype=float), required=True, hint=[1.0, 0.0, 0.0]
type=types.ndarray(shape=(3,), dtype=float), required=True, hint=[1.0, 0.0, 0.0]
)
"""The first vector identifying the parallelepiped (using the right-hand orientation: the
thumb)."""
side_vector_2 = config.attr(
type=types.ndarray(dtype=float), required=True, hint=[0.0, 1.0, 0.0]
type=types.ndarray(shape=(3,), dtype=float), required=True, hint=[0.0, 1.0, 0.0]
)
"""The second vector identifying the parallelepiped (using the right-hand orientation: the
index)."""
side_vector_3 = config.attr(
type=types.ndarray(dtype=float), required=True, hint=[0.0, 0.0, 1.0]
type=types.ndarray(shape=(3,), dtype=float), required=True, hint=[0.0, 0.0, 1.0]
)
"""The third vector identifying the parallelepiped (using the right-hand orientation: the
middle finger)."""
Expand Down
2 changes: 1 addition & 1 deletion bsb/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ def place_cells(
chunk=None,
):
"""
Place cells inside of the scaffold
Place cells inside the scaffold
.. code-block:: python
Expand Down
2 changes: 1 addition & 1 deletion bsb/services/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,7 @@ def execute(self, return_results=False):
"""
Execute the jobs in the queue
In serial execution this runs all of the jobs in the queue in First In First Out
In serial execution this runs all the jobs in the queue in First In First Out
order. In parallel execution this enqueues all jobs into the MPIPool unless they
have dependencies that need to complete first.
"""
Expand Down
5 changes: 4 additions & 1 deletion bsb/simulation/targetting.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,11 +177,14 @@ class CylindricalTargetting(
Targets all cells in a cylinder along specified axis.
"""

origin: np.ndarray[float] = config.attr(type=types.ndarray(dtype=float))
origin: np.ndarray[float] = config.attr(type=types.ndarray(shape=(2,), dtype=float))
"""Coordinates of the base of the cylinder for each non main axis"""
axis: typing.Union[typing.Literal["x"], typing.Literal["y"], typing.Literal["z"]] = (
config.attr(type=types.in_(["x", "y", "z"]), default="y")
)
"""Main axis of the cylinder"""
radius: float = config.attr(type=float, required=True)
"""Radius of the cylinder"""

@FractionFilter.filter
def get_targets(self, adapter, simulation, simdata):
Expand Down

0 comments on commit cabeb50

Please sign in to comment.