diff --git a/.github/devops/generate_public_api.py b/.github/devops/generate_public_api.py
index df32c072..378bfa23 100644
--- a/.github/devops/generate_public_api.py
+++ b/.github/devops/generate_public_api.py
@@ -36,6 +36,10 @@ def get_public_api_map():
if isinstance(el, ast.Constant)
]
for api in module_api:
+ if api in public_api_map:
+ raise RuntimeError(
+ f"Duplicate api key: bsb.{module}.{api} and bsb.{public_api_map[api]}.{api}"
+ )
public_api_map[api] = module
return public_api_map
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 748ece20..705ff3cd 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -7,4 +7,10 @@ repos:
rev: 5.12.0
hooks:
- id: isort
- name: isort (python)
\ No newline at end of file
+ name: isort (python)
+ - repo: local
+ hooks:
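+      # Runs the public API generator, which (with the change above) fails on duplicate API keys.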
+ - id: api-test
+ name: api-test
+ entry: python3 .github/devops/generate_public_api.py
+ language: system
diff --git a/CHANGELOG b/CHANGELOG
index 729feaab..609fb346 100644
--- a/CHANGELOG
+++ b/CHANGELOG
@@ -1,3 +1,19 @@
+# 4.1.1
+* Fix resolution of `file_ref` during configuration parsing when importing nodes.
+* Use a stricter rule for job enqueuing.
+* Use certifi to fix SSL certificate issues.
+
+# 4.1.0
+* Added `ParsesReferences` mixin from bsb-json to allow references and imports in configuration files.
+  This also includes recursive parsing of the configuration files.
+* Added `swap_axes` function in morphologies
+* Added API test in pre-commit config and fixed duplicate entries.
+* Fix `PackageRequirement`, `ConfigurationListAttribute`, and `ConfigurationDictAttribute`
+  inverse functions
+* Refactor `CodeDependencyNode` module attribute to be either a module-like string or a path string
+* Fix `assert_same_len`
+* Fix `test_particle_vd`
+
# 4.0.0 - Too much to list
## 4.0.0a32
diff --git a/README.md b/README.md
index dcf5db0f..09503cef 100644
--- a/README.md
+++ b/README.md
@@ -1,7 +1,7 @@
[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
[![Documentation Status](https://readthedocs.org/projects/bsb/badge/?version=latest)](https://bsb.readthedocs.io/en/latest/?badge=latest)
-[![Build Status](https://travis-ci.com/dbbs-lab/bsb.svg?branch=master)](https://travis-ci.com/dbbs-lab/bsb)
-[![codecov](https://codecov.io/gh/dbbs-lab/bsb/branch/master/graph/badge.svg)](https://codecov.io/gh/dbbs-lab/bsb)
+[![Build Status](https://travis-ci.com/dbbs-lab/bsb-core.svg?branch=main)](https://travis-ci.com/dbbs-lab/bsb-core)
+[![codecov](https://codecov.io/gh/dbbs-lab/bsb-core/branch/main/graph/badge.svg)](https://codecov.io/gh/dbbs-lab/bsb-core)
:closed_book: Read the documentation on https://bsb.readthedocs.io/en/latest
@@ -25,14 +25,14 @@ Any package in the BSB ecosystem can be installed from PyPI through `pip`. Most
will want to install the main [bsb](https://pypi.org/project/bsb/) framework:
```
-pip install "bsb~=4.0"
+pip install "bsb~=4.1"
```
Advanced users looking to install an unconventional combination of plugins might
be better off installing just this package, and the desired plugins:
```
-pip install "bsb-core~=4.0"
+pip install "bsb-core~=4.1"
```
Note that `bsb-core` does not come with any plugins installed and the usually
diff --git a/bsb/__init__.py b/bsb/__init__.py
index 8d4f09a9..5b9bfe5f 100644
--- a/bsb/__init__.py
+++ b/bsb/__init__.py
@@ -7,7 +7,7 @@
install the `bsb` package instead.
"""
-__version__ = "4.0.1"
+__version__ = "4.1.1"
import functools
import importlib
@@ -142,7 +142,7 @@ def __dir__():
BaseCommand: typing.Type["bsb.cli.commands.BaseCommand"]
BidirectionalContact: typing.Type["bsb.postprocessing.BidirectionalContact"]
BootError: typing.Type["bsb.exceptions.BootError"]
-BoxTree: typing.Type["bsb.voxels.BoxTree"]
+BoxTree: typing.Type["bsb.trees.BoxTree"]
BoxTreeInterface: typing.Type["bsb.trees.BoxTreeInterface"]
Branch: typing.Type["bsb.morphologies.Branch"]
BranchLocTargetting: typing.Type["bsb.simulation.targetting.BranchLocTargetting"]
@@ -216,6 +216,8 @@ def __dir__():
ExternalSourceError: typing.Type["bsb.exceptions.ExternalSourceError"]
FileDependency: typing.Type["bsb.storage._files.FileDependency"]
FileDependencyNode: typing.Type["bsb.storage._files.FileDependencyNode"]
+FileImportError: typing.Type["bsb.exceptions.FileImportError"]
+FileReferenceError: typing.Type["bsb.exceptions.FileReferenceError"]
FileScheme: typing.Type["bsb.storage._files.FileScheme"]
FileStore: typing.Type["bsb.storage.interfaces.FileStore"]
FixedIndegree: typing.Type["bsb.connectivity.general.FixedIndegree"]
@@ -292,6 +294,7 @@ def __dir__():
ParameterError: typing.Type["bsb.exceptions.ParameterError"]
ParameterValue: typing.Type["bsb.simulation.parameter.ParameterValue"]
ParserError: typing.Type["bsb.exceptions.ParserError"]
+ParsesReferences: typing.Type["bsb.config.parsers.ParsesReferences"]
Partition: typing.Type["bsb.topology.partition.Partition"]
PlacementError: typing.Type["bsb.exceptions.PlacementError"]
PlacementIndications: typing.Type["bsb.cell_types.PlacementIndications"]
diff --git a/bsb/_util.py b/bsb/_util.py
index a3759661..0cb530b7 100644
--- a/bsb/_util.py
+++ b/bsb/_util.py
@@ -95,7 +95,7 @@ def assert_samelen(*args):
"""
len_ = None
assert all(
- (len_ := len(arg) if len_ is None else len(arg)) == len_ for arg in args
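        # Bind len_ to the first argument's length; every later length must equal it.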
+ ((len_ := len(arg)) if len_ is None else len(arg)) == len_ for arg in args
), "Input arguments should be of same length."
diff --git a/bsb/config/_attrs.py b/bsb/config/_attrs.py
index 4c53c1fe..9aec9e56 100644
--- a/bsb/config/_attrs.py
+++ b/bsb/config/_attrs.py
@@ -3,6 +3,7 @@
"""
import builtins
+from functools import wraps
import errr
@@ -476,7 +477,7 @@ def __set__(self, instance, value):
self.attr_name,
) from e
self.flag_dirty(instance)
- # The value was cast to its intented type and the new value can be set.
+ # The value was cast to its intended type and the new value can be set.
self.fset(instance, value)
root = _strict_root(instance)
if _is_booted(root):
@@ -687,7 +688,15 @@ def fill(self, value, _parent, _key=None):
def _set_type(self, type, key=None):
self.child_type = super()._set_type(type, key=False)
- return self.fill
+
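+        # Return a fresh wrapper instead of the bound fill method, so that
+        # attributes like __inv__ can be attached without mutating the method.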
+ @wraps(self.fill)
+ def wrapper(*args, **kwargs):
+ return self.fill(*args, **kwargs)
+
+ # Expose children __inv__ function if it exists
+ if hasattr(self.child_type, "__inv__"):
+ setattr(wrapper, "__inv__", self.child_type.__inv__)
+ return wrapper
def tree(self, instance):
val = _getattr(instance, self.attr_name)
@@ -838,7 +847,15 @@ def fill(self, value, _parent, _key=None):
def _set_type(self, type, key=None):
self.child_type = super()._set_type(type, key=False)
- return self.fill
+
+ @wraps(self.fill)
+ def wrapper(*args, **kwargs):
+ return self.fill(*args, **kwargs)
+
+ # Expose children __inv__ function if it exists
+ if hasattr(self.child_type, "__inv__"):
+ setattr(wrapper, "__inv__", self.child_type.__inv__)
+ return wrapper
def tree(self, instance):
val = _getattr(instance, self.attr_name).items()
diff --git a/bsb/config/_config.py b/bsb/config/_config.py
index 7c8464cb..599b806e 100644
--- a/bsb/config/_config.py
+++ b/bsb/config/_config.py
@@ -7,7 +7,7 @@
from ..cell_types import CellType
from ..connectivity import ConnectionStrategy
from ..placement import PlacementStrategy
-from ..postprocessing import AfterPlacementHook
+from ..postprocessing import AfterConnectivityHook, AfterPlacementHook
from ..simulation.simulation import Simulation
from ..storage._files import (
CodeDependencyNode,
@@ -132,8 +132,8 @@ class Configuration:
"""
Network connectivity strategies
"""
- after_connectivity: cfgdict[str, AfterPlacementHook] = config.dict(
- type=AfterPlacementHook,
+ after_connectivity: cfgdict[str, AfterConnectivityHook] = config.dict(
+ type=AfterConnectivityHook,
)
simulations: cfgdict[str, Simulation] = config.dict(
type=Simulation,
diff --git a/bsb/config/_parse_types.py b/bsb/config/_parse_types.py
new file mode 100644
index 00000000..f16507c0
--- /dev/null
+++ b/bsb/config/_parse_types.py
@@ -0,0 +1,174 @@
+from __future__ import annotations
+
+from .. import warn
+from ..exceptions import ConfigurationWarning, FileImportError, ParserError
+
+
+class parsed_node:
+ _key = None
+ """Key to reference the node"""
+ _parent: parsed_node = None
+ """Parent node"""
+
+ def location(self):
+ return "/" + "/".join(str(part) for part in self._location_parts([]))
+
+ def _location_parts(self, carry):
+ parent = self
+ while (parent := parent._parent) is not None:
+ if parent._parent is not None:
+ carry.insert(0, parent._key)
+ carry.append(self._key or "")
+ return carry
+
+    def __str__(self):
+        return f"<parsed node '{self.location()}'>"
+
+ def __repr__(self):
+ return super().__str__()
+
+
+def _traverse_wrap(node, iter):
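+    # Recursively replace plain dicts/lists with their parsed_* counterparts,
+    # recording each child's key and parent along the way.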
+ for key, value in iter:
+ if type(value) in recurse_handlers:
+ value, iter = recurse_handlers[type(value)](value, node)
+ value._key = key
+ value._parent = node
+ node[key] = value
+ _traverse_wrap(value, iter)
+
+
+class parsed_dict(dict, parsed_node):
+ def merge(self, other):
+ """
+ Recursively merge the values of another dictionary into us
+ """
+ for key, value in other.items():
+ if key in self and isinstance(self[key], dict) and isinstance(value, dict):
+ if not isinstance(self[key], parsed_dict): # pragma: nocover
+ self[key] = d = parsed_dict(self[key])
+ d._key = key
+ d._parent = self
+ self[key].merge(value)
+ elif isinstance(value, dict):
+ self[key] = d = parsed_dict(value)
+ d._key = key
+ d._parent = self
+ _traverse_wrap(d, d.items())
+ else:
+ if isinstance(value, list):
+ value = parsed_list(value)
+ value._key = key
+ value._parent = self
+ self[key] = value
+
+ def rev_merge(self, other):
+ """
+ Recursively merge ourselves onto another dictionary
+ """
+ m = parsed_dict(other)
+ _traverse_wrap(m, m.items())
+ m.merge(self)
+ self.clear()
+ self.update(m)
+ for v in self.values():
+ if hasattr(v, "_parent"):
+ v._parent = self
+
+
+class parsed_list(list, parsed_node):
+ pass
+
+
+def _prep_dict(node, parent):
+ return parsed_dict(node), node.items()
+
+
+def _prep_list(node, parent):
+ return parsed_list(node), enumerate(node)
+
+
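+# Maps raw container types to the helper that wraps them and yields their items.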
+recurse_handlers = {
+ dict: _prep_dict,
+ parsed_dict: _prep_dict,
+ list: _prep_list,
+ parsed_list: _prep_list,
+}
+
+
+class file_ref:
+ def __init__(self, node, doc, ref):
+ self.node = node
+ self.doc = doc
+ self.ref = ref
+ self.key_path = node.location()
+
+ def resolve(self, parser, target):
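+        # A `$ref` pulls the whole referenced dictionary in, then lets the local
+        # keys of this node override the referenced ones (via rev_merge).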
+ del self.node["$ref"]
+ self.node.rev_merge(target)
+
+    def __str__(self):
+        return "<file ref '{}'>".format(
+            ((self.doc + "#") if self.doc else "") + self.ref
+        )
+
+
+class file_imp(file_ref):
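+    # An `$import` copies only the keys listed under `values` from the referenced
+    # node, warning instead of overwriting scalar keys that already exist locally.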
+ def __init__(self, node, doc, ref, values):
+ super().__init__(node, doc, ref)
+ self.values = values
+
+ def resolve(self, parser, target):
+ del self.node["$import"]
+ for key in self.values:
+ if key not in target:
+ raise FileImportError(
+ "'{}' does not exist in import node '{}'".format(key, self.ref)
+ )
+ if isinstance(target[key], dict):
+ imported = parsed_dict()
+ imported.merge(target[key])
+ imported._key = key
+ imported._parent = self.node
+ if key in self.node:
+ if isinstance(self.node[key], dict):
+ imported.merge(self.node[key])
+ else:
+ warn(
+                            f"Import key '{key}' of {self} is ignored because the parent"
+ f" already contains a key '{key}'"
+ f" with value '{self.node[key]}'.",
+ ConfigurationWarning,
+ stacklevel=3,
+ )
+ continue
+ self.node[key] = imported
+ self._fix_references(self.node[key], parser)
+ elif isinstance(target[key], list):
+ imported, iter = _prep_list(target[key], self.node)
+ imported._key = key
+ imported._parent = self.node
+ self.node[key] = imported
+ self._fix_references(self.node[key], parser)
+ _traverse_wrap(imported, iter)
+ else:
+ self.node[key] = target[key]
+
+ def _fix_references(self, node, parser):
+ # fix parser's references after the import.
+ if hasattr(parser, "references"):
+ for ref in parser.references:
+ node_loc = node.location()
+ if node_loc in ref.key_path:
+ # need to update the reference
+ # we descend the tree from the node until we reach the ref
+ # It should be here because of the merge.
+ loc_node = node
+ while node_loc != ref.key_path:
+                        # Descend one level: take the first path component below node_loc.
+                        remainder = ref.key_path.split(node_loc, 1)[-1].lstrip("/")
+                        key = remainder.split("/", 1)[0]
+ if key not in loc_node: # pragma: nocover
+ raise ParserError(
+ f"Reference {ref.key_path} not found in {node_loc}. "
+ f"Should have been merged."
+ )
+                        loc_node = loc_node[key]
+ node_loc += "/" + key
+ ref.node = loc_node
diff --git a/bsb/config/parsers.py b/bsb/config/parsers.py
index fb455e1f..5b5c6061 100644
--- a/bsb/config/parsers.py
+++ b/bsb/config/parsers.py
@@ -1,19 +1,195 @@
import abc
import functools
+import os
-from ..exceptions import PluginError
+from ..exceptions import FileReferenceError, PluginError
+from ._parse_types import file_imp, file_ref, parsed_dict, recurse_handlers
class ConfigurationParser(abc.ABC):
@abc.abstractmethod
- def parse(self, content, path=None):
+ def parse(self, content, path=None): # pragma: nocover
+ """
+ Parse configuration file content.
+
+ :param content: str or dict content of the file to parse
+ :param path: path to the file containing the configuration.
+ :return: configuration tree and metadata attached as dictionaries
+ """
pass
@abc.abstractmethod
- def generate(self, tree, pretty=False):
+ def generate(self, tree, pretty=False): # pragma: nocover
+ """
+ Generate a string representation of the configuration tree (dictionary).
+
+ :param dict tree: configuration tree
+ :param bool pretty: if True, will add indentation to the output string
+ :return: str representation of the configuration tree
+ :rtype: str
+ """
pass
+class ParsesReferences:
+ """
+ Mixin to decorate parse function of ConfigurationParser.
+ Allows for imports and references inside configuration files.
+ """
+
+ def __init_subclass__(cls, **kwargs):
+ super().__init_subclass__(**kwargs)
+ parse = cls.parse
+
+ def parse_with_references(self, content, path=None):
+ """Traverse the parsed tree and resolve any `$ref` and `$import`"""
+ content, meta = parse(self, content, path)
+ content = parsed_dict(content)
+ self.root = content
+ self.path = path or os.getcwd()
+ self.is_doc = path and not os.path.isdir(path)
+ self.references = []
+ self.documents = {}
+ self._traverse(content, content.items())
+ self.resolved_documents = {}
+ self._resolve_documents()
+ self._resolve_references()
+ return content, meta
+
+ cls.parse = parse_with_references
+
+ def _traverse(self, node, iter):
+ # Iterates over all values in `iter` and checks for import keys, recursion or refs
+ # Also wraps all nodes in their `parsed_*` counterparts.
+ for key, value in iter:
+ if self._is_import(key):
+ self._store_import(node)
+ elif type(value) in recurse_handlers:
+ # The recurse handlers wrap the dicts and lists and return appropriate
+ # iterators for them.
+ value, iter = recurse_handlers[type(value)](value, node)
+ # Set some metadata on the wrapped recursable objects.
+ value._key = key
+ value._parent = node
+ # Overwrite the reference to the original object with a reference to the
+ # wrapped object.
+ node[key] = value
+ # Recurse a level deeper
+ self._traverse(value, iter)
+ elif self._is_reference(key):
+ self._store_reference(node, value)
+
+ def _is_reference(self, key):
+ return key == "$ref"
+
+ def _is_import(self, key):
+ return key == "$import"
+
+ def _get_ref_document(self, ref, base=None):
+ if "#" not in ref or ref.split("#")[0] == "":
+ return None
+ doc = ref.split("#")[0]
+ if not os.path.isabs(doc):
+ if not base:
+                # references should be relative to the current configuration file
+                # to avoid recursion issues.
+ base = os.path.dirname(self.path)
+ elif not os.path.isdir(base):
+ base = os.path.dirname(base)
+ if not os.path.exists(base):
+ raise IOError("Can't find reference directory '{}'".format(base))
+ doc = os.path.abspath(os.path.join(base, doc))
+ return doc
+
+ @staticmethod
+ def _get_absolute_ref(node, ref):
+ ref = ref.split("#")[-1]
+ if ref.startswith("/"):
+ path = ref
+ else:
+ path = os.path.join(node.location(), ref)
+ return os.path.normpath(path).replace(os.path.sep, "/")
+
+ def _store_reference(self, node, ref):
+ # Analyzes the reference and creates a ref object from the given data
+ doc = self._get_ref_document(ref, self.path)
+ ref = self._get_absolute_ref(node, ref)
+ if doc not in self.documents:
+ self.documents[doc] = set()
+ self.documents[doc].add(ref)
+ self.references.append(file_ref(node, doc, ref))
+
+ def _store_import(self, node):
+ # Analyzes the import node and creates a ref object from the given data
+ imp = node["$import"]
+ ref = imp["ref"]
+ doc = self._get_ref_document(ref)
+ ref = self._get_absolute_ref(node, ref)
+ if doc not in self.documents:
+ self.documents[doc] = set()
+ self.documents[doc].add(ref)
+ if "values" not in imp:
+ e = RuntimeError(f"Import node {node} is missing a 'values' list.")
+ e._bsbparser_show_user = True
+ raise e
+ self.references.append(file_imp(node, doc, ref, imp["values"]))
+
+ def _resolve_documents(self):
+        # Iterates over the list of stored documents, parses them, and fetches
+        # the content of each reference node.
+ for file, refs in self.documents.items():
+ if file is None:
+ content = self.root
+ else:
+ from . import _try_parsers
+
+ parser_classes = get_configuration_parser_classes()
+ ext = file.split(".")[-1]
+ with open(file, "r") as f:
+ content = f.read()
+ _, content, _ = _try_parsers(content, parser_classes, ext, path=file)
+ try:
+ self.resolved_documents[file] = self._resolve_document(content, refs)
+ except FileReferenceError as jre:
+ if not file:
+ raise
+ raise FileReferenceError(
+ str(jre) + " in document '{}'".format(file)
+ ) from None
+
+ def _resolve_document(self, content, refs):
+ resolved = {}
+ for ref in refs:
+ resolved[ref] = self._fetch_reference(content, ref)
+ return resolved
+
+ def _fetch_reference(self, content, ref):
+ parts = [p for p in ref.split("/")[1:] if p]
+ n = content
+ loc = ""
+ for part in parts:
+ loc += "/" + part
+ try:
+ n = n[part]
+ except KeyError:
+ raise FileReferenceError(
+ "'{}' in File reference '{}' does not exist".format(loc, ref)
+ ) from None
+ if not isinstance(n, dict):
+ raise FileReferenceError(
+ "File references can only point to dictionaries. '{}' is a {}".format(
+ "{}' in '{}".format(loc, ref) if loc != ref else ref,
+ type(n).__name__,
+ )
+ )
+ return n
+
+ def _resolve_references(self):
+ for ref in self.references:
+ target = self.resolved_documents[ref.doc][ref.ref]
+ ref.resolve(self, target)
+
+
@functools.cache
def get_configuration_parser_classes():
from ..plugins import discover
@@ -36,6 +212,7 @@ def get_configuration_parser(parser, **kwargs):
__all__ = [
"ConfigurationParser",
+ "ParsesReferences",
"get_configuration_parser",
"get_configuration_parser_classes",
]
diff --git a/bsb/config/types.py b/bsb/config/types.py
index aee713d4..ceea9315 100644
--- a/bsb/config/types.py
+++ b/bsb/config/types.py
@@ -740,6 +740,40 @@ def requirement(section):
return requirement
+def same_size(*list_attrs, required=True):
+ """
+ Requirement handler for list attributes that should have the same size.
+
+ :param list_attrs: The keys of the list attributes.
+ :type list_attrs: str
+    :param required: Whether all of the keys must be provided together
+ :type required: bool
+ :returns: Requirement function
+ :rtype: Callable
+ """
+ listed = ", ".join(f"`{m}`" for m in list_attrs[:-1])
+ if len(list_attrs) > 1:
+ listed += f" {{}} `{list_attrs[-1]}`"
+
+ def requirement(section):
+ common_size = -1
+ count = 0
+ for m in list_attrs:
+ if m in section:
+ v = builtins.list(section[m])
+ if len(v) != common_size and common_size >= 0:
+ err_msg = f"The {listed} attributes should have the same size."
+ raise RequirementError(err_msg)
+ common_size = len(v)
+ count += 1
+        if required and count != len(list_attrs):
+ err_msg = f"The {listed} attributes are required."
+ raise RequirementError(err_msg)
+ return False
+
+ return requirement
+
+
def shortform():
def requirement(section):
return not section.is_shortform
@@ -755,7 +789,12 @@ class ndarray(TypeHandler):
:rtype: Callable
"""
+ def __init__(self, dtype=None):
+ self.dtype = dtype
+
def __call__(self, value):
+ if self.dtype is not None:
+ return np.array(value, copy=False, dtype=self.dtype)
return np.array(value, copy=False)
@property
@@ -780,14 +819,16 @@ class PackageRequirement(TypeHandler):
def __call__(self, value):
from packaging.requirements import Requirement
- return Requirement(value)
+ requirement = Requirement(value)
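+        # Keep the raw config string so __inv__ can round-trip it verbatim,
+        # since str(Requirement) may normalize the original spelling.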
+ requirement._cfg_inv = value
+ return requirement
@property
def __name__(self):
return "package requirement"
def __inv__(self, value):
- return str(value)
+ return getattr(value, "_cfg_inv", builtins.str(value))
def __hint__(self):
return "numpy==1.24.0"
diff --git a/bsb/connectivity/__init__.py b/bsb/connectivity/__init__.py
index 42dda4d3..19b20c2d 100644
--- a/bsb/connectivity/__init__.py
+++ b/bsb/connectivity/__init__.py
@@ -5,4 +5,5 @@
# isort: on
from .detailed import *
from .general import *
+from .geometric import *
from .import_ import CsvImportConnectivity
diff --git a/bsb/connectivity/detailed/shared.py b/bsb/connectivity/detailed/shared.py
index 928bf2a2..3754ced1 100644
--- a/bsb/connectivity/detailed/shared.py
+++ b/bsb/connectivity/detailed/shared.py
@@ -1,4 +1,3 @@
-from functools import cache
from itertools import chain
import numpy as np
@@ -15,8 +14,8 @@ class Intersectional:
def get_region_of_interest(self, chunk):
post_ps = [ct.get_placement_set() for ct in self.postsynaptic.cell_types]
- lpre, upre = self._get_rect_ext(tuple(chunk.dimensions), True)
- lpost, upost = self._get_rect_ext(tuple(chunk.dimensions), False)
+ lpre, upre = self.presynaptic._get_rect_ext(tuple(chunk.dimensions))
+ lpost, upost = self.postsynaptic._get_rect_ext(tuple(chunk.dimensions))
# Get the `np.arange`s between bounds offset by the chunk position, to be used in
# `np.meshgrid` below.
bounds = list(
@@ -42,27 +41,6 @@ def get_region_of_interest(self, chunk):
size = next(iter(self._occ_chunks)).dimensions
return [t for c in clist if (t := Chunk(c, size)) in self._occ_chunks]
- @cache
- def _get_rect_ext(self, chunk_size, pre_post_flag):
- if pre_post_flag:
- types = self.presynaptic.cell_types
- loader = self.presynaptic.morpho_loader
- else:
- types = self.postsynaptic.cell_types
- loader = self.postsynaptic.morpho_loader
- ps_list = [ct.get_placement_set() for ct in types]
- ms_list = [loader(ps) for ps in ps_list]
- if not sum(map(len, ms_list)):
- # No cells placed, return smallest possible RoI.
- return [np.array([0, 0, 0]), np.array([0, 0, 0])]
- metas = list(chain.from_iterable(ms.iter_meta(unique=True) for ms in ms_list))
- # TODO: Combine morphology extension information with PS rotation information.
- # Get the chunk coordinates of the boundaries of this chunk convoluted with the
- # extension of the intersecting morphologies.
- lbounds = np.min([m["ldc"] for m in metas], axis=0) // chunk_size
- ubounds = np.max([m["mdc"] for m in metas], axis=0) // chunk_size
- return lbounds, ubounds
-
def candidate_intersection(self, target_coll, candidate_coll):
target_cache = [
(tset.cell_type, tset, tset.load_boxes()) for tset in target_coll.placement
diff --git a/bsb/connectivity/geometric/__init__.py b/bsb/connectivity/geometric/__init__.py
new file mode 100644
index 00000000..75e397bf
--- /dev/null
+++ b/bsb/connectivity/geometric/__init__.py
@@ -0,0 +1,4 @@
+from .geometric_shapes import *
+from .morphology_shape_intersection import MorphologyToShapeIntersection
+from .shape_morphology_intersection import ShapeToMorphologyIntersection
+from .shape_shape_intersection import ShapeHemitype, ShapeToShapeIntersection
diff --git a/bsb/connectivity/geometric/geometric_shapes.py b/bsb/connectivity/geometric/geometric_shapes.py
new file mode 100644
index 00000000..387596d2
--- /dev/null
+++ b/bsb/connectivity/geometric/geometric_shapes.py
@@ -0,0 +1,1290 @@
+from __future__ import annotations
+
+import abc
+from typing import List, Tuple
+
+import numpy as np
+from scipy.interpolate import interp1d
+from scipy.spatial.transform import Rotation as R
+
+from ... import config
+from ...config import types
+
+
+def _reshape_vectors(rot_pts, x, y, z):
+ xrot = rot_pts[:, 0].reshape(x.shape)
+ yrot = rot_pts[:, 1].reshape(y.shape)
+ zrot = rot_pts[:, 2].reshape(z.shape)
+
+ return xrot, yrot, zrot
+
+
+def rotate_3d_mesh_by_vec(
+ x: np.ndarray, y: np.ndarray, z: np.ndarray, rot_versor: np.ndarray, angle: float
+):
+ """
+ Rotate meshgrid points according to a rotation versor and angle.
+
+ :param numpy.ndarray[numpy.ndarray[float]] x: x coordinate points of the meshgrid
+ :param numpy.ndarray[numpy.ndarray[float]] y: y coordinate points of the meshgrid
+ :param numpy.ndarray[numpy.ndarray[float]] z: z coordinate points of the meshgrid
+ :param numpy.ndarray[float] rot_versor: vector representing rotation versor
+ :param float angle: rotation angle in radian
+ :return: Rotated x, y, z coordinate points
+    :rtype: Tuple[numpy.ndarray[numpy.ndarray[float]]]
+ """
+
+ # Arrange point coordinates in shape (N, 3) for vectorized processing
+ pts = np.array([x.ravel(), y.ravel(), z.ravel()]).transpose()
+
+ # Create and apply rotation
+ rot = R.from_rotvec(rot_versor * angle)
+ rot_pts = rot.apply(pts)
+
+ # return to original shape of meshgrid
+ return _reshape_vectors(rot_pts, x, y, z)
+
+
+def translate_3d_mesh_by_vec(
+ x: np.ndarray, y: np.ndarray, z: np.ndarray, t_vec: np.ndarray
+):
+ """
+ Translate meshgrid points according to a 3d vector.
+
+ :param numpy.ndarray[numpy.ndarray[float]] x: x coordinate points of the meshgrid
+ :param numpy.ndarray[numpy.ndarray[float]] y: y coordinate points of the meshgrid
+ :param numpy.ndarray[numpy.ndarray[float]] z: z coordinate points of the meshgrid
+ :param numpy.ndarray[float] t_vec: translation vector
+ :return: Translated x, y, z coordinate points
+    :rtype: Tuple[numpy.ndarray[numpy.ndarray[float]]]
+ """
+
+ # Arrange point coordinates in shape (N, 3) for vectorized processing
+ pts = np.array([x.ravel(), y.ravel(), z.ravel()]).transpose()
+ pts = pts + t_vec
+ # return to original shape of meshgrid
+ return _reshape_vectors(pts, x, y, z)
+
+
+def rotate_3d_mesh_by_rot_mat(
+ x: np.ndarray, y: np.ndarray, z: np.ndarray, rot_mat: np.ndarray
+):
+ """
+ Rotate meshgrid points according to a rotation matrix.
+
+ :param numpy.ndarray[numpy.ndarray[float]] x: x coordinate points of the meshgrid
+ :param numpy.ndarray[numpy.ndarray[float]] y: y coordinate points of the meshgrid
+ :param numpy.ndarray[numpy.ndarray[float]] z: z coordinate points of the meshgrid
+ :param numpy.ndarray[numpy.ndarray[float]] rot_mat: rotation matrix, shape (3,3)
+ :return: Rotated x, y, z coordinate points
+    :rtype: Tuple[numpy.ndarray[numpy.ndarray[float]]]
+ """
+
+ # Arrange point coordinates in shape (N, 3) for vectorized processing
+ pts = np.array([x.ravel(), y.ravel(), z.ravel()]).transpose()
+
+ # Create and apply rotation
+ rot = R.from_matrix(rot_mat)
+ rot_pts = rot.apply(pts)
+
+ # return to original shape of meshgrid
+ return _reshape_vectors(rot_pts, x, y, z)
+
+
+def _surface_resampling(
+ surface_function,
+ theta_min=0,
+ theta_max=2 * np.pi,
+ phi_min=0,
+ phi_max=np.pi,
+ precision=25,
+):
+ # first sampling to estimate surface distribution
+ theta, phi = np.meshgrid(
+ np.linspace(theta_min, theta_max, precision),
+ np.linspace(phi_min, phi_max, precision),
+ )
+ coords = surface_function(theta, phi)
+
+ # estimate surfaces, decomposing it into parallelograms along theta and phi
+ delta_t_temp = np.diff(coords, axis=2)
+ delta_u_temp = np.diff(coords, axis=1)
+ delta_t = np.zeros(coords.shape)
+ delta_u = np.zeros(coords.shape)
+ delta_t[: coords.shape[0], : coords.shape[1], 1 : coords.shape[2]] = delta_t_temp
+ delta_u[: coords.shape[0], 1 : coords.shape[1], : coords.shape[2]] = delta_u_temp
+ delta_S = np.linalg.norm(np.cross(delta_t, delta_u, 0, 0), axis=2)
+ cum_S_t = np.cumsum(delta_S.sum(axis=0))
+ cum_S_u = np.cumsum(delta_S.sum(axis=1))
+ return theta, phi, cum_S_t, cum_S_u
+
+
+def uniform_surface_sampling(
+ n_points,
+ surface_function,
+ theta_min=0,
+ theta_max=2 * np.pi,
+ phi_min=0,
+ phi_max=np.pi,
+ precision=25,
+):
+ """
+ Uniform-like random sampling of polar coordinates based on surface estimation.
+ This sampling is useful on elliptic surfaces (e.g. sphere).
+ Algorithm based on https://github.com/maxkapur/param_tools
+
+ :param int n_points: number of points to sample
+ :param Callable[..., numpy.ndarray[float]] surface_function: function converting polar
+ coordinates into cartesian coordinates
+ :param int precision: size of grid used to estimate function surface
+ """
+
+ theta, phi, cum_S_t, cum_S_u = _surface_resampling(
+ surface_function, theta_min, theta_max, phi_min, phi_max, precision
+ )
+ # resample along the cumulative surface to uniformize point distribution
+ # equivalent to a multinomial sampling
+ sampled_t = np.random.rand(n_points) * cum_S_t[-1]
+ sampled_u = np.random.rand(n_points) * cum_S_u[-1]
+ sampled_t = interp1d(cum_S_t, theta[0, :])(sampled_t)
+ sampled_u = interp1d(cum_S_u, phi[:, 0])(sampled_u)
+
+ return surface_function(sampled_t, sampled_u)
+
+
+def uniform_surface_wireframe(
+ n_points_1,
+ n_points_2,
+ surface_function,
+ theta_min=0,
+ theta_max=2 * np.pi,
+ phi_min=0,
+ phi_max=np.pi,
+ precision=25,
+):
+ """
+    Uniform-like meshgrid of size (n_points_1, n_points_2) of polar coordinates based on surface
+ estimation.
+ This meshgrid is useful on elliptic surfaces (e.g. sphere).
+ Algorithm based on https://github.com/maxkapur/param_tools
+
+ :param Callable[..., numpy.ndarray[float]] surface_function: function converting polar
+ coordinates into cartesian coordinates
+ :param int precision: size of grid used to estimate function surface
+ """
+
+ theta, phi, cum_S_t, cum_S_u = _surface_resampling(
+ surface_function, theta_min, theta_max, phi_min, phi_max, precision
+ )
+ sampled_t = np.linspace(0, cum_S_t[-1], n_points_1)
+ sampled_u = np.linspace(0, cum_S_u[-1], n_points_2)
+ sampled_t = interp1d(cum_S_t, theta[0, :])(sampled_t)
+ sampled_u = interp1d(cum_S_u, phi[:, 0])(sampled_u)
+ sampled_t, sampled_u = np.meshgrid(sampled_t, sampled_u)
+ return surface_function(sampled_t, sampled_u)
+
+
+def _get_prod_angle_vector(hv, z_versor=np.array([0, 0, 1])):
+ """
+    Calculate the cross product of two vectors and the angle between them.
+
+ :param numpy.ndarray hv: vector to rotate
+ :param numpy.ndarray z_versor: reference vector
+    :return: cross product and angle in radians
+ :rtype: Tuple[numpy.ndarray, numpy.ndarray]
+ """
+
+ return np.cross(z_versor, hv), np.arccos(np.dot(hv, z_versor))
+
+
+def _get_rotation_vector(hv, z_versor=np.array([0, 0, 1]), positive_angle=True):
+ """
+ Calculate the rotation vector between two vectors.
+
+ :param numpy.ndarray hv: vector to rotate
+ :param numpy.ndarray z_versor: reference vector
+ :param bool positive_angle: if False, the angle is inverted
+ :return: rotation vector
+ :rtype: scipy.spatial.transform.Rotation
+ """
+
+ perp, angle = _get_prod_angle_vector(hv, z_versor)
+ angle = angle if positive_angle else -angle
+ rot = R.from_rotvec(perp * angle)
+ return rot
+
+
+def _rotate_by_coord(x, y, z, hv, origin, test_hv=False):
+ perp, angle = _get_prod_angle_vector(hv / np.linalg.norm(hv))
+
+ x, y, z = rotate_3d_mesh_by_vec(x, y, z, perp, angle)
+ if test_hv and hv[2] < 0:
+ z = -z
+ return translate_3d_mesh_by_vec(x, y, z, origin)
+
+
+def _get_extrema_after_rot(extrema, origin, top_center):
+ height = np.linalg.norm(top_center - origin)
+ rot = _get_rotation_vector((top_center - origin) / height)
+
+ for i, pt in enumerate(extrema):
+ extrema[i] = rot.apply(pt)
+
+ return np.min(extrema, axis=0) + origin, np.max(extrema, axis=0) + origin
+
+
+def inside_mbox(
+ points: np.ndarray[float],
+ mbb_min: np.ndarray[float],
+ mbb_max: np.ndarray[float],
+):
+ """
+ Check if the points given in input are inside the minimal bounding box.
+
+ :param numpy.ndarray points: An array of 3D points.
+ :param numpy.ndarray mbb_min: 3D point representing the lowest coordinate of the
+ minimal bounding box.
+ :param numpy.ndarray mbb_max: 3D point representing the highest coordinate of the
+ minimal bounding box.
+ :return: A bool np.ndarray specifying whether each point of the input array is inside the
+ minimal bounding box or not.
+ :rtype: numpy.ndarray[bool]
+ """
+ inside = (
+ (points[:, 0] > mbb_min[0])
+ & (points[:, 0] < mbb_max[0])
+ & (points[:, 1] > mbb_min[1])
+ & (points[:, 1] < mbb_max[1])
+ & (points[:, 2] > mbb_min[2])
+ & (points[:, 2] < mbb_max[2])
+ )
+
+ return inside
+
+
+@config.dynamic(attr_name="type", default="shape", auto_classmap=True)
+class GeometricShape(abc.ABC):
+ """
+ Base class for geometric shapes
+ """
+
+ epsilon = config.attr(type=float, required=False, default=1.0e-3)
+ """Tolerance value to compare coordinates."""
+
+ def __init__(self, **kwargs):
+ self.mbb_min, self.mbb_max = self.find_mbb()
+
+ @abc.abstractmethod
+ def find_mbb(self): # pragma: no cover
+ """
+ Compute the minimum bounding box surrounding the shape.
+ """
+ pass
+
+ def check_mbox(self, points: np.ndarray[float]):
+ """
+ Check if the points given in input are inside the minimal bounding box.
+
+ :param numpy.ndarray points: A cloud of points.
+ :return: A bool np.ndarray specifying whether each point of the input array is inside the
+ minimal bounding box or not.
+ :rtype: numpy.ndarray
+ """
+ return inside_mbox(points, self.mbb_min, self.mbb_max)
+
+ @abc.abstractmethod
+ def get_volume(self): # pragma: no cover
+ """
+        Get the volume of the geometric shape.
+
+        :return: The volume of the geometric shape.
+ :rtype: float
+ """
+ pass
+
+ @abc.abstractmethod
+ def translate(self, t_vector: np.ndarray[float]): # pragma: no cover
+ """
+ Translate the geometric shape by the vector t_vector.
+
+ :param numpy.ndarray t_vector: The displacement vector
+ """
+ pass
+
+ @abc.abstractmethod
+ def rotate(self, r_versor: np.ndarray[float], angle: float): # pragma: no cover
+ """
+        Rotate the geometric shape around the axis defined by r_versor, a unit vector
+        passing through the origin, by the specified angle.
+
+ :param r_versor: A versor specifying the rotation axis.
+ :type r_versor: numpy.ndarray[float]
+ :param float angle: the rotation angle, in radians.
+ """
+ pass
+
+ @abc.abstractmethod
+ def generate_point_cloud(self, npoints: int): # pragma: no cover
+ """
+ Generate a point cloud made by npoints points.
+
+ :param int npoints: The number of points to generate.
+ :return: a (npoints x 3) numpy array.
+ :rtype: numpy.ndarray
+ """
+ pass
+
+ @abc.abstractmethod
+ def check_inside(
+ self, points: np.ndarray[float]
+ ) -> np.ndarray[bool]: # pragma: no cover
+ """
+ Check if the points given in input are inside the geometric shape.
+
+ :param numpy.ndarray points: A cloud of points.
+        :return: A bool array with same length as points, containing whether the i-th
+            point is inside the geometric shape or not.
+ :rtype: numpy.ndarray
+ """
+ pass
+
+ @abc.abstractmethod
+ def wireframe_points(self, nb_points_1=30, nb_points_2=30): # pragma: no cover
+ """
+ Generate a wireframe to plot the geometric shape.
+ If a sampling of points is needed (e.g. for sphere), the wireframe is based on a grid
+ of shape (nb_points_1, nb_points_2).
+
+ :param int nb_points_1: number of points sampled along the first dimension
+ :param int nb_points_2: number of points sampled along the second dimension
+ :return: Coordinate components of the wireframe
+        :rtype: Tuple[numpy.ndarray[numpy.ndarray[float]]]
+ """
+ pass
+
+
+@config.node
+class ShapesComposition:
+ """
+ A collection of geometric shapes, which can be labelled to distinguish different parts of a
+ neuron.
+ """
+
+ shapes = config.list(
+ type=GeometricShape,
+ required=types.same_size("shapes", "labels", required=True),
+ hint=[{"type": "sphere", "radius": 40.0, "center": [0.0, 0.0, 0.0]}],
+ )
+ """List of GeometricShape that make up the neuron."""
+ labels = config.list(
+ type=types.list(),
+ required=types.same_size("shapes", "labels", required=True),
+ hint=[["soma", "dendrites", "axon"]],
+ )
+ """List of lists of labels associated to each geometric shape."""
+ voxel_size = config.attr(type=float, required=False, default=1.0)
+ """Dimension of the side of a voxel, used to determine how many points must be generated
+ to represent the geometric shape."""
+
+ def __init__(self, **kwargs):
+ # The two corners individuating the minimal bounding box.
+ self._mbb_min = np.array([0.0, 0.0, 0.0])
+ self._mbb_max = np.array([0.0, 0.0, 0.0])
+
+ self.find_mbb()
+
+ def add_shape(self, shape: GeometricShape, labels: List[str]):
+ """
+ Add a geometric shape to the collection
+
+ :param GeometricShape shape: A GeometricShape to add to the collection.
+ :param List[str] labels: A list of labels for the geometric shape to add.
+ """
+ # Update mbb
+ if len(self._shapes) == 0:
+ self._mbb_min = np.copy(shape.mbb_min)
+ self._mbb_max = np.copy(shape.mbb_max)
+ else:
+ self._mbb_min = np.minimum(self._mbb_min, shape.mbb_min)
+ self._mbb_max = np.maximum(self._mbb_max, shape.mbb_max)
+ self._shapes.append(shape)
+ self._labels.append(labels)
+
+ def filter_by_labels(self, labels: List[str]) -> ShapesComposition:
+ """
+        Filter the collection of shapes, returning only the ones corresponding to the
+        given labels.
+
+ :param List[str] labels: A list of labels.
+ :return: A new ShapesComposition object containing only the shapes labelled as specified.
+ :rtype: ShapesComposition
+ """
+ result = ShapesComposition(dict(voxel_size=self.voxel_size, labels=[], shapes=[]))
+ selected_id = np.where(np.isin(labels, self._labels))[0]
+ result._shapes = [self._shapes[i].__copy__() for i in selected_id]
+ result._labels = [self._labels[i].copy() for i in selected_id]
+ result.mbb_min, result.mbb_max = result.find_mbb()
+ return result
+
+ def translate(self, t_vec: np.ndarray[float]):
+ """
+ Translate all the shapes in the collection by the vector t_vec. It also automatically
+        translates the minimal bounding box.
+
+ :param numpy.ndarray t_vec: The displacement vector.
+ """
+ for shape in self._shapes:
+ shape.translate(t_vec)
+ self._mbb_min += t_vec
+ self._mbb_max += t_vec
+
+ def get_volumes(self) -> List[float]:
+ """
+ Compute the volumes of all the shapes.
+
+ :rtype: List[float]
+ """
+ return [shape.get_volume() for shape in self._shapes]
+
+ def get_mbb_min(self):
+ """
+ Returns the bottom corner of the minimum bounding box containing the collection of shapes.
+
+ :return: The bottom corner individuating the minimal bounding box of the shapes collection.
+ :rtype: numpy.ndarray[float]
+ """
+ return self._mbb_min
+
+ def get_mbb_max(self):
+ """
+ Returns the top corner of the minimum bounding box containing the collection of shapes.
+
+ :return: The top corner individuating the minimal bounding box of the shapes collection.
+ :rtype: numpy.ndarray[float]
+ """
+ return self._mbb_max
+
+ def find_mbb(self) -> Tuple[np.ndarray[float], np.ndarray[float]]:
+ """
+ Compute the minimal bounding box containing the collection of shapes.
+
+ :return: The two corners individuating the minimal bounding box of the shapes collection.
+ :rtype: Tuple(numpy.ndarray[float], numpy.ndarray[float])
+ """
+ mins = np.empty([len(self._shapes), 3])
+ maxs = np.empty([len(self._shapes), 3])
+ for i, shape in enumerate(self._shapes):
+ mins[i, :] = shape.mbb_min
+ maxs[i, :] = shape.mbb_max
+ self._mbb_min = np.min(mins, axis=0) if len(self._shapes) > 0 else np.zeros(3)
+ self._mbb_max = np.max(maxs, axis=0) if len(self._shapes) > 0 else np.zeros(3)
+ return self._mbb_min, self._mbb_max
+
+ def compute_n_points(self) -> List[int]:
+ """
+        Compute the number of points to generate in a point cloud, using the voxel
+        dimension specified in self.voxel_size.
+
+        :return: The number of points to generate for each shape.
+        :rtype: List[int]
+ """
+ return [int(shape.get_volume() // self.voxel_size**3) for shape in self._shapes]
+
+ def generate_point_cloud(self) -> np.ndarray[float] | None:
+ """
+ Generate a point cloud. The number of points to generate is determined automatically using
+ the voxel size.
+
+ :return: A numpy.ndarray containing the 3D points of the cloud. If there are no shapes in
+ the collection, it returns None.
+ :rtype: numpy.ndarray[float] | None
+ """
+ if len(self._shapes) != 0:
+ return np.concatenate(
+ [
+ shape.generate_point_cloud(numpts)
+ for shape, numpts in zip(self._shapes, self.compute_n_points())
+ ]
+ )
+ else:
+ return None
+
+ def generate_wireframe(
+ self,
+ nb_points_1=30,
+ nb_points_2=30,
+ ) -> Tuple[List, List, List] | None:
+ """
+ Generate the wireframes of a collection of shapes.
+ If a sampling of points is needed for certain shapes (e.g. for sphere), their wireframe
+ is based on a grid of shape (nb_points_1, nb_points_2).
+
+ :param int nb_points_1: number of points sampled along the first dimension
+ :param int nb_points_2: number of points sampled along the second dimension
+ :return: The x,y,z coordinates of the wireframe of each shape.
+ :rtype: Tuple[List[numpy.ndarray[numpy.ndarray[float]]]] | None
+ """
+ if len(self._shapes) != 0:
+ x = []
+ y = []
+ z = []
+ for shape in self._shapes:
+ # For each shape, the shape of the wireframe is different, so we need to append them
+ # manually
+ xt, yt, zt = shape.wireframe_points(
+ nb_points_1=nb_points_1, nb_points_2=nb_points_2
+ )
+ x.append(xt)
+ y.append(yt)
+ z.append(zt)
+ return x, y, z
+ return None
+
+ def inside_mbox(self, points: np.ndarray[float]) -> np.ndarray[bool]:
+ """
+ Check if the points given in input are inside the minimal bounding box of the collection.
+
+ :param numpy.ndarray points: An array of 3D points.
+ :return: A bool np.ndarray specifying whether each point of the input array is inside the
+ minimal bounding box of the collection.
+ :rtype: numpy.ndarray[bool]
+ """
+ return inside_mbox(points, self._mbb_min, self._mbb_max)
+
+ def inside_shapes(self, points: np.ndarray[float]) -> np.ndarray[bool] | None:
+ """
+        Check if the points given in input are inside at least one of the shapes of the
+ collection.
+
+ :param numpy.ndarray points: An array of 3D points.
+ :return: A bool numpy.ndarray specifying whether each point of the input array is inside the
+ collection of shapes or not.
+ :rtype: numpy.ndarray[bool]
+ """
+ if len(self._shapes) != 0:
+ is_inside = np.full(len(points), 0, dtype=bool)
+ for shape in self._shapes:
+ tmp = shape.check_mbox(points)
+ if np.any(tmp):
+ is_inside = is_inside | shape.check_inside(points)
+ return is_inside
+ else:
+ return None
+
+
+@config.node
+class Ellipsoid(GeometricShape, classmap_entry="ellipsoid"):
+ """
+ An ellipsoid, described in cartesian coordinates.
+ """
+
+ origin = config.attr(
+ type=types.ndarray(dtype=float), required=True, hint=[0.0, 0.0, 0.0]
+ )
+ """The coordinates of the center of the ellipsoid."""
+ lambdas = config.attr(
+ type=types.ndarray(dtype=float), required=True, hint=[1.0, 0.5, 2.0]
+ )
+ """The length of the three semi-axes."""
+
+ @config.property(type=types.ndarray(), required=True)
+ def v0(self):
+ """The versor on which the first semi-axis lies."""
+ return self._v0
+
+ @v0.setter
+ def v0(self, value):
+ self._v0 = np.copy(value) / np.linalg.norm(value)
+
+ @config.property(type=types.ndarray(), required=True)
+ def v1(self):
+ """The versor on which the second semi-axis lies."""
+ return self._v1
+
+ @v1.setter
+ def v1(self, value):
+ self._v1 = np.copy(value) / np.linalg.norm(value)
+
+ @config.property(type=types.ndarray(), required=True)
+ def v2(self):
+ """The versor on which the third semi-axis lies."""
+ return self._v2
+
+ @v2.setter
+ def v2(self, value):
+ self._v2 = np.copy(value) / np.linalg.norm(value)
+
+ def find_mbb(self):
+ # Find the minimum bounding box, to avoid computing it every time
+ extrema = (
+ np.array(
+ [
+ self.lambdas[0] * self.v0,
+ -self.lambdas[0] * self.v0,
+ self.lambdas[1] * self.v1,
+ -self.lambdas[1] * self.v1,
+ self.lambdas[2] * self.v2,
+ -self.lambdas[2] * self.v2,
+ ]
+ )
+ + self.origin
+ )
+ mbb_min = np.min(extrema, axis=0)
+ mbb_max = np.max(extrema, axis=0)
+ return mbb_min, mbb_max
+
+ def get_volume(self):
+        # Volume of an ellipsoid: 4/3 * pi * a * b * c
+        return np.pi * 4.0 / 3.0 * self.lambdas[0] * self.lambdas[1] * self.lambdas[2]
+
+ def translate(self, t_vector: np.ndarray):
+ self.origin += t_vector
+ self.mbb_min += t_vector
+ self.mbb_max += t_vector
+
+ def rotate(self, r_versor: np.ndarray[float], angle: float):
+ rot = R.from_rotvec(r_versor * angle)
+ self.v0 = rot.apply(self.v0)
+ self.v1 = rot.apply(self.v1)
+ self.v2 = rot.apply(self.v2)
+
+ def surface_point(self, theta, phi):
+ """
+ Convert polar coordinates into their 3D location on the ellipsoid surface.
+
+ :param float|numpy.ndarray[float] theta: first polar coordinate in [0; 2*np.pi]
+ :param float|numpy.ndarray[float] phi: second polar coordinate in [0; np.pi]
+ :return: surface coordinates
+ :rtype: float|numpy.ndarray[float]
+ """
+ return np.array(
+ [
+ self.lambdas[0] * np.cos(theta) * np.sin(phi),
+ self.lambdas[1] * np.sin(theta) * np.sin(phi),
+ self.lambdas[2] * np.cos(phi),
+ ]
+ )
+
+ def generate_point_cloud(self, npoints: int):
+ sampling = uniform_surface_sampling(npoints, self.surface_point)
+ sampling = sampling.T * np.random.rand(npoints, 3) # sample within the shape
+
+ # Rotate the ellipse
+ rmat = np.array([self.v0, self.v1, self.v2]).T
+ sampling = sampling.dot(rmat)
+ sampling = sampling + self.origin
+ return sampling
+
+ def check_inside(self, points: np.ndarray[float]):
+        # Check if the quadratic form associated to the ellipsoid is less than 1 at a point
+        diff = points - self.origin
+        vmat = np.array([self.v0, self.v1, self.v2])
+        diag = np.diag(1 / self.lambdas**2)
+        # Quadratic form Q = V^T D V, with the axis versors as the rows of V
+        qmat = vmat.T.dot(diag).dot(vmat)
+        quad_prod = np.diagonal(diff.dot(qmat.dot(diff.T)))
+
+ # Check if the points are inside the ellipsoid
+ inside_points = quad_prod < 1
+ return inside_points
+
+ def wireframe_points(self, nb_points_1=30, nb_points_2=30):
+        # Generate an ellipsoid oriented along x, y, z
+ x, y, z = uniform_surface_wireframe(nb_points_1, nb_points_2, self.surface_point)
+ # Rotate the ellipse
+ rmat = np.array([self.v0, self.v1, self.v2]).T
+ x, y, z = rotate_3d_mesh_by_rot_mat(x, y, z, rmat)
+ return translate_3d_mesh_by_vec(x, y, z, self.origin)
+
+
+@config.node
+class Cone(GeometricShape, classmap_entry="cone"):
+ """
+ A cone, described in cartesian coordinates.
+ """
+
+ apex = config.attr(
+ type=types.ndarray(dtype=float), required=True, hint=[0.0, 1.0, 0.0]
+ )
+ """The coordinates of the apex of the cone."""
+ origin = config.attr(
+ type=types.ndarray(dtype=float), required=True, hint=[0.0, 0.0, 0.0]
+ )
+ """The coordinates of the center of the cone's base."""
+ radius = config.attr(type=float, required=False, default=1.0)
+ """The radius of the base circle."""
+
+ def find_mbb(self):
+ # Vectors identifying half of the sides of the base rectangle in xy
+ u = np.array([self.radius, 0, 0])
+ v = np.array([0, self.radius, 0])
+
+ # Find the rotation angle and axis
+ hv = self.origin - self.apex
+ rot = _get_rotation_vector(hv / np.linalg.norm(hv))
+
+ # Rotated vectors of the box
+ v1 = rot.apply(u)
+ v2 = rot.apply(v)
+ v3 = self.origin - self.apex
+
+        # Corners of the rotated base square, plus the apex itself (zero vector),
+        # all relative to the apex
+        corners = [v3 + s1 * v1 + s2 * v2 for s1 in (-1, 1) for s2 in (-1, 1)]
+        corners.append(np.zeros(3))
+
+        # Coordinates identifying the minimal bounding box, in absolute coordinates
+        minima = np.min(corners, axis=0) + self.apex
+        maxima = np.max(corners, axis=0) + self.apex
+        return minima, maxima
+
+ def get_volume(self):
+ h = np.linalg.norm(self.apex - self.origin)
+ b = np.pi * self.radius * self.radius
+ return b * h / 3
+
+ def translate(self, t_vector):
+ self.origin += t_vector
+ self.apex += t_vector
+ self.mbb_min += t_vector
+ self.mbb_max += t_vector
+
+ def rotate(self, r_versor: np.ndarray[float], angle: float):
+ rot = R.from_rotvec(r_versor * angle)
+ self.apex = rot.apply(self.apex)
+
+ def generate_point_cloud(self, npoints: int):
+ theta = np.pi * 2.0 * np.random.rand(npoints)
+ rand_a = np.random.rand(npoints)
+ rand_b = np.random.rand(npoints)
+
+ # Height vector
+ hv = self.origin - self.apex
+ cloud = np.full((npoints, 3), 0, dtype=float)
+
+        # Generate a cone with its apex at the origin and its base center at (0, 0, |hv|)
+ cloud[:, 0] = (self.radius * rand_a * np.cos(theta)) * rand_b
+ cloud[:, 1] = self.radius * rand_a * np.sin(theta)
+ cloud[:, 2] = rand_a * np.linalg.norm(hv)
+
+ # Rotate the cone: Find the axis of rotation and compute the angle
+ perp, angle = _get_prod_angle_vector(hv / np.linalg.norm(hv))
+
+ if hv[2] < 0:
+ cloud[:, 2] = -cloud[:, 2]
+ rot = R.from_rotvec(perp * angle)
+ cloud = rot.apply(cloud)
+
+ # Translate the cone
+ cloud = cloud + self.apex
+ return cloud
+
+ def check_inside(self, points: np.ndarray[float]):
+ # Find the vector w of the height.
+ h_vector = self.origin - self.apex
+ height = np.linalg.norm(h_vector)
+ h_vector /= height
+ # Center the points
+ pts = points - self.apex
+
+ # Rotate back to xyz
+ rot = _get_rotation_vector(h_vector)
+ rot_pts = rot.apply(pts)
+
+ # Find the angle between the points and the apex
+ apex_angles = np.arccos(
+ np.dot((rot_pts / np.linalg.norm(rot_pts, axis=1)[..., np.newaxis]), h_vector)
+ )
+ # Compute the cone angle
+ cone_angle = np.arctan(self.radius / height)
+
+ # Select the points inside the cone
+ inside_points = (
+ (apex_angles < cone_angle + self.epsilon)
+ & (rot_pts[:, 2] > np.min([self.origin[2], self.apex[2]]) - self.epsilon)
+ & (rot_pts[:, 2] < np.max([self.origin[2], self.apex[2]]) + self.epsilon)
+ )
+ return inside_points
+
+ def wireframe_points(self, nb_points_1=30, nb_points_2=30):
+ # Set up the grid in polar coordinates
+ theta = np.linspace(0, 2 * np.pi, nb_points_1)
+ r = np.linspace(0, self.radius, nb_points_2)
+ theta, r = np.meshgrid(theta, r)
+
+ # Height vector
+ hv = np.array(self.origin) - np.array(self.apex)
+ height = np.linalg.norm(hv)
+ # angle = np.arctan(height/self.radius)
+
+        # Generate a cone with its apex at the origin and its base center on the z-axis
+ x = r * np.cos(theta)
+ y = r * np.sin(theta)
+ z = r * height / self.radius
+
+ # Rotate the cone
+ return _rotate_by_coord(x, y, z, hv, self.apex, test_hv=True)
+
+
+@config.node
+class Cylinder(GeometricShape, classmap_entry="cylinder"):
+ """
+ A cylinder, described in cartesian coordinates.
+ """
+
+ origin = config.attr(
+ type=types.ndarray(dtype=float), required=True, hint=[0.0, 0.0, 0.0]
+ )
+ """The coordinates of the center of the bottom circle of the cylinder."""
+ top_center = config.attr(
+ type=types.ndarray(dtype=float), required=True, hint=[0.0, 2.0, 0.0]
+ )
+ """The coordinates of the center of the top circle of the cylinder."""
+ radius = config.attr(type=float, required=False, default=1.0)
+ """The radius of the base circle."""
+
+ def find_mbb(self):
+ height = np.linalg.norm(self.top_center - self.origin)
+ # Extrema of the xyz standard cyl
+ extrema = [
+ np.array([-self.radius, -self.radius, 0.0]),
+ np.array([-self.radius, self.radius, 0.0]),
+ np.array([self.radius, -self.radius, 0.0]),
+ np.array([self.radius, self.radius, 0.0]),
+ np.array([self.radius, self.radius, height]),
+ np.array([-self.radius, self.radius, height]),
+ np.array([self.radius, -self.radius, height]),
+ np.array([-self.radius, -self.radius, height]),
+ ]
+
+ # Rotate the cylinder
+
+ return _get_extrema_after_rot(extrema, self.origin, self.top_center)
+
+ def get_volume(self):
+ h = np.linalg.norm(self.top_center - self.origin)
+ b = np.pi * self.radius * self.radius
+ return b * h
+
+ def translate(self, t_vector: np.ndarray[float]):
+ self.origin += t_vector
+ self.top_center += t_vector
+ self.mbb_min += t_vector
+ self.mbb_max += t_vector
+
+ def rotate(self, r_versor: np.ndarray[float], angle: float):
+ rot = R.from_rotvec(r_versor * angle)
+        # rotation of the top center around the bottom center
+        self.top_center = rot.apply(self.top_center - self.origin) + self.origin
+
+ def generate_point_cloud(self, npoints: int):
+        # Sample points for a cylinder aligned with the z-axis
+ cloud = np.full((npoints, 3), 0, dtype=float)
+ theta = np.pi * 2.0 * np.random.rand(npoints)
+ rand = np.random.rand(npoints, 3)
+ height = np.linalg.norm(self.top_center - self.origin)
+
+        # Generate a cylinder along the z-axis, with its base centered at the origin
+ cloud[:, 0] = self.radius * np.cos(theta)
+ cloud[:, 1] = self.radius * np.sin(theta)
+ cloud[:, 2] = height
+ cloud = cloud * rand
+
+ # Rotate the cylinder
+ hv = (self.top_center - self.origin) / height
+ rot = _get_rotation_vector(hv / np.linalg.norm(hv))
+ cloud = rot.apply(cloud)
+ cloud = cloud + self.origin
+ return cloud
+
+ def check_inside(self, points: np.ndarray[float]):
+ # Translate back to origin
+ pts = points - self.origin
+
+ # Rotate back to xyz
+ height = np.linalg.norm(self.top_center - self.origin)
+ rot = _get_rotation_vector(
+ (self.top_center - self.origin) / height, positive_angle=False
+ )
+ rot_pts = rot.apply(pts)
+ # Check for intersections
+ inside_points = (
+ (rot_pts[:, 2] < height + self.epsilon)
+ & (rot_pts[:, 2] > -self.epsilon)
+ & (
+ rot_pts[:, 0] * rot_pts[:, 0] + rot_pts[:, 1] * rot_pts[:, 1]
+ < self.radius**2 + self.epsilon
+ )
+ )
+ return inside_points
+
+ def wireframe_points(self, nb_points_1=30, nb_points_2=30):
+ # Set up the grid in polar coordinates
+ theta = np.linspace(0, 2 * np.pi, nb_points_1)
+
+ # Height vector
+ hv = np.array(self.origin) - np.array(self.top_center)
+ height = np.linalg.norm(hv)
+
+ h = np.linspace(0, height, nb_points_2)
+ theta, h = np.meshgrid(theta, h)
+
+ x = self.radius * np.cos(theta)
+ y = self.radius * np.sin(theta)
+ z = h
+
+ # Rotate the cylinder
+ hv = (self.top_center - self.origin) / height
+ return _rotate_by_coord(x, y, z, hv, self.origin)
+
+
+@config.node
+class Sphere(GeometricShape, classmap_entry="sphere"):
+ """
+ A sphere, described in cartesian coordinates.
+ """
+
+ origin = config.attr(
+ type=types.ndarray(dtype=float), required=True, hint=[0.0, 0.0, 0.0]
+ )
+ """The coordinates of the center of the sphere."""
+ radius = config.attr(type=float, required=False, default=1.0)
+ """The radius of the sphere."""
+
+ def find_mbb(self):
+ # Find the minimum bounding box, to avoid computing it every time
+ mbb_min = np.array([-self.radius, -self.radius, -self.radius]) + self.origin
+ mbb_max = np.array([self.radius, self.radius, self.radius]) + self.origin
+ return mbb_min, mbb_max
+
+ def get_volume(self):
+ return np.pi * 4.0 / 3.0 * np.power(self.radius, 3)
+
+ def translate(self, t_vector: np.ndarray[float]):
+ self.origin += t_vector
+ self.mbb_min += t_vector
+ self.mbb_max += t_vector
+
+ def rotate(self, r_versor: np.ndarray[float], angle: float): # pragma: no cover
+ # It's a sphere, it's invariant under rotation!
+ pass
+
+ def surface_function(self, theta, phi):
+ """
+ Convert polar coordinates into their 3D location on the sphere surface.
+
+ :param float|numpy.ndarray[float] theta: first polar coordinate in [0; 2*np.pi]
+ :param float|numpy.ndarray[float] phi: second polar coordinate in [0; np.pi]
+ :return: surface coordinates
+ :rtype: float|numpy.ndarray[float]
+ """
+ return np.array(
+ [
+ self.radius * np.cos(theta) * np.sin(phi),
+ self.radius * np.sin(theta) * np.sin(phi),
+ self.radius * np.cos(phi),
+ ]
+ )
+
+ def generate_point_cloud(self, npoints: int):
+ # Generate a sphere centered at the origin.
+ cloud = uniform_surface_sampling(npoints, self.surface_function)
+ cloud = cloud.T * np.random.rand(npoints, 3) # sample within the shape
+
+ cloud = cloud + self.origin
+
+ return cloud
+
+ def check_inside(self, points: np.ndarray[float]):
+ # Translate the points, bringing the origin to the center of the sphere,
+ # then check the inequality defining the sphere
+ pts_centered = points - self.origin
+ lhs = np.linalg.norm(pts_centered, axis=1)
+ inside_points = lhs < self.radius + self.epsilon
+ return inside_points
+
+ def wireframe_points(self, nb_points_1=30, nb_points_2=30):
+ x, y, z = uniform_surface_wireframe(
+ nb_points_1, nb_points_2, self.surface_function
+ )
+ return translate_3d_mesh_by_vec(x, y, z, self.origin)
+
+
+@config.node
+class Cuboid(GeometricShape, classmap_entry="cuboid"):
+ """
+ A rectangular parallelepiped, described in cartesian coordinates.
+ """
+
+ origin = config.attr(
+ type=types.ndarray(dtype=float), required=True, hint=[0.0, 0.0, 0.0]
+ )
+ """The coordinates of the center of the barycenter of the bottom rectangle."""
+ top_center = config.attr(
+ type=types.ndarray(dtype=float), required=True, hint=[0.0, 1.0, 0.0]
+ )
+ """The coordinates of the center of the barycenter of the top rectangle."""
+ side_length_1 = config.attr(type=float, required=False, default=1.0)
+ """Length of one side of the base rectangle."""
+ side_length_2 = config.attr(type=float, required=False, default=1.0)
+ """Length of the other side of the base rectangle."""
+
+ def find_mbb(self):
+ # Extrema of the cuboid centered at the origin
+ extrema = [
+ np.array([-self.side_length_1 / 2.0, -self.side_length_2 / 2.0, 0.0]),
+ np.array([self.side_length_1 / 2.0, self.side_length_2 / 2.0, 0.0]),
+ np.array([-self.side_length_1 / 2.0, self.side_length_2 / 2.0, 0.0]),
+ np.array([self.side_length_1 / 2.0, -self.side_length_2 / 2.0, 0.0]),
+ np.array(
+ [
+ self.side_length_1 / 2.0 + self.top_center[0],
+ self.side_length_2 / 2.0 + self.top_center[1],
+ self.top_center[2],
+ ]
+ ),
+ np.array(
+ [
+ -self.side_length_1 / 2.0 + self.top_center[0],
+ self.side_length_2 / 2.0 + self.top_center[1],
+ self.top_center[2],
+ ]
+ ),
+ np.array(
+ [
+ -self.side_length_1 / 2.0 + self.top_center[0],
+ -self.side_length_2 / 2.0 + self.top_center[1],
+ self.top_center[2],
+ ]
+ ),
+ np.array(
+ [
+ self.side_length_1 / 2.0 + self.top_center[0],
+ -self.side_length_2 / 2.0 + self.top_center[1],
+ self.top_center[2],
+ ]
+ ),
+ ]
+
+ # Rotate the cuboid
+ return _get_extrema_after_rot(extrema, self.origin, self.top_center)
+
+ def get_volume(self):
+ h = np.linalg.norm(self.top_center - self.origin)
+ return h * self.side_length_1 * self.side_length_2
+
+ def translate(self, t_vector: np.ndarray[float]):
+ self.origin += t_vector
+ self.top_center += t_vector
+ self.mbb_min += t_vector
+ self.mbb_max += t_vector
+
+ def rotate(self, r_versor: np.ndarray[float], angle: float):
+ rot = R.from_rotvec(r_versor * angle)
+ self.top_center = rot.apply(self.top_center)
+
+ def generate_point_cloud(self, npoints: int):
+ # Generate a unit cuboid whose base rectangle has the barycenter in the origin
+ rand = np.random.rand(npoints, 3)
+ rand[:, 0] = rand[:, 0] - 0.5
+ rand[:, 1] = rand[:, 1] - 0.5
+
+ # Scale the sides of the cuboid; rand spans [-0.5, 0.5], so multiplying by the
+ # full side length yields the [-side/2, side/2] extent
+ height = np.linalg.norm(self.top_center - self.origin)
+ rand[:, 0] = rand[:, 0] * self.side_length_1
+ rand[:, 1] = rand[:, 1] * self.side_length_2
+ rand[:, 2] = rand[:, 2] * height
+
+ # Rotate the cuboid
+ rot = _get_rotation_vector((self.top_center - self.origin) / height)
+ cloud = rot.apply(rand)
+
+ # Translate the cuboid
+ cloud = cloud + self.origin
+ return cloud
+
+ def check_inside(self, points: np.ndarray[float]):
+ # Translate back to origin
+ pts = points - self.origin
+
+ # Rotate back to xyz
+ height = np.linalg.norm(self.top_center - self.origin)
+ rot = _get_rotation_vector(
+ (self.top_center - self.origin) / height, positive_angle=False
+ )
+ rot_pts = rot.apply(pts)
+
+ # Check that the points fall within the cuboid's extent; the base rectangle is
+ # centered on the origin, so the x and y extents are half the side lengths
+ inside_points = (
+ (rot_pts[:, 2] < height)
+ & (rot_pts[:, 2] > 0.0)
+ & (rot_pts[:, 0] < self.side_length_1 / 2.0)
+ & (rot_pts[:, 0] > -self.side_length_1 / 2.0)
+ & (rot_pts[:, 1] < self.side_length_2 / 2.0)
+ & (rot_pts[:, 1] > -self.side_length_2 / 2.0)
+ )
+ return inside_points
+
+ def wireframe_points(self, **kwargs):
+ a = self.side_length_1 / 2.0
+ b = self.side_length_2 / 2.0
+ c = np.linalg.norm(self.top_center - self.origin)
+
+ x = np.array(
+ [
+ [-a, a, a, -a], # x coordinates of the bottom surface
+ [-a, a, a, -a], # x coordinates of the top surface
+ [-a, a, -a, a], # x coordinates of the outside surface
+ [-a, a, -a, a], # x coordinates of the inside surface
+ ]
+ )
+ y = np.array(
+ [
+ [-b, -b, b, b], # y coordinates of the bottom surface
+ [-b, -b, b, b], # y coordinates of the top surface
+ [-b, -b, -b, -b], # y coordinates of the outside surface
+ [b, b, b, b], # y coordinates of the inside surface
+ ]
+ )
+ z = np.array(
+ [
+ [0.0, 0.0, 0.0, 0.0], # z coordinates of the bottom surface
+ [c, c, c, c], # z coordinates of the top surface
+ [0.0, 0.0, c, c], # z coordinates of the outside surface
+ [0.0, 0.0, c, c], # z coordinates of the inside surface
+ ]
+ )
+
+ # Rotate the cuboid
+ hv = (self.top_center - self.origin) / c
+ return _rotate_by_coord(x, y, z, hv, self.origin)
+
+
+@config.node
+class Parallelepiped(GeometricShape, classmap_entry="parallelepiped"):
+ """
+ A generic parallelepiped, described by the vectors (following the right-hand orientation) of the
+ sides in cartesian coordinates
+ """
+
+ origin = config.attr(
+ type=types.ndarray(dtype=float), required=True, hint=[0.0, 0.0, 0.0]
+ )
+ """The coordinates of the left-bottom edge."""
+ side_vector_1 = config.attr(
+ type=types.ndarray(dtype=float), required=True, hint=[1.0, 0.0, 0.0]
+ )
+ """The first vector identifying the parallelepiped (using the right-hand orientation: the
+ thumb)."""
+ side_vector_2 = config.attr(
+ type=types.ndarray(dtype=float), required=True, hint=[0.0, 1.0, 0.0]
+ )
+ """The second vector identifying the parallelepiped (using the right-hand orientation: the
+ index)."""
+ side_vector_3 = config.attr(
+ type=types.ndarray(dtype=float), required=True, hint=[0.0, 0.0, 1.0]
+ )
+ """The third vector identifying the parallelepiped (using the right-hand orientation: the
+ middle finger)."""
+
+ def find_mbb(self):
+ extrema = np.vstack(
+ [
+ np.array([0.0, 0.0, 0.0]),
+ self.side_vector_1 + self.side_vector_2 + self.side_vector_3,
+ ]
+ )
+ maxima = np.max(extrema, axis=0) + self.origin
+ minima = np.min(extrema, axis=0) + self.origin
+ return minima, maxima
+
+ def get_volume(self):
+ vol = np.dot(self.side_vector_3, np.cross(self.side_vector_1, self.side_vector_2))
+ return vol
+
+ def translate(self, t_vector: np.ndarray[float]):
+ self.origin += t_vector
+ self.mbb_min += t_vector
+ self.mbb_max += t_vector
+
+ def rotate(self, r_versor: np.ndarray[float], angle: float):
+ rot = R.from_rotvec(r_versor * angle)
+ self.side_vector_1 = rot.apply(self.side_vector_1)
+ self.side_vector_2 = rot.apply(self.side_vector_2)
+ self.side_vector_3 = rot.apply(self.side_vector_3)
+
+ def generate_point_cloud(self, npoints: int):
+ # Random linear combinations of the three side vectors, with coefficients
+ # in [0, 1), fill the volume of the parallelepiped
+ rand = np.random.rand(npoints, 3)
+ sides = np.vstack([self.side_vector_1, self.side_vector_2, self.side_vector_3])
+ cloud = rand @ sides + self.origin
+ return cloud
+
+ def check_inside(self, points: np.ndarray[float]):
+ # Translate back to origin
+ pts = points - self.origin
+
+ # Rotate back to xyz
+ height = np.linalg.norm(self.side_vector_3)
+ rot = _get_rotation_vector(hv=self.side_vector_3 / height, positive_angle=True)
+ rot_pts = rot.apply(pts)
+
+ # Compute the projections of the points onto the vectors identifying the
+ # parallelepiped
+ v1_norm = np.linalg.norm(self.side_vector_1)
+ comp1 = rot_pts.dot(self.side_vector_1) / v1_norm
+ v2_norm = np.linalg.norm(self.side_vector_2)
+ comp2 = rot_pts.dot(self.side_vector_2) / v2_norm
+ v3_norm = np.linalg.norm(self.side_vector_3)
+ comp3 = rot_pts.dot(self.side_vector_3) / v3_norm
+
+ # The points are inside the parallelepiped if and only if all the projections
+ # are between 0 and the length of the corresponding side (exact when the side
+ # vectors are orthogonal)
+ inside_points = (
+ (comp1 > 0.0)
+ & (comp1 < v1_norm)
+ & (comp2 > 0.0)
+ & (comp2 < v2_norm)
+ & (comp3 > 0.0)
+ & (comp3 < v3_norm)
+ )
+ return inside_points
+
+ def wireframe_points(self, **kwargs):
+ va = self.side_vector_1
+ vb = self.side_vector_2
+ vc = self.side_vector_3
+
+ a = va
+ b = va + vb
+ c = vb
+ d = np.array([0.0, 0.0, 0.0])
+ e = va + vc
+ f = va + vb + vc
+ g = vb + vc
+ h = vc
+
+ x = np.array(
+ [
+ [a[0], b[0], c[0], d[0]],
+ [e[0], f[0], g[0], h[0]],
+ [a[0], b[0], f[0], e[0]],
+ [d[0], c[0], g[0], h[0]],
+ ]
+ )
+ y = np.array(
+ [
+ [a[1], b[1], c[1], d[1]],
+ [e[1], f[1], g[1], h[1]],
+ [a[1], b[1], f[1], e[1]],
+ [d[1], c[1], g[1], h[1]],
+ ]
+ )
+ z = np.array(
+ [
+ [a[2], b[2], c[2], d[2]],
+ [e[2], f[2], g[2], h[2]],
+ [a[2], b[2], f[2], e[2]],
+ [d[2], c[2], g[2], h[2]],
+ ]
+ )
+
+ return x + self.origin[0], y + self.origin[1], z + self.origin[2]
diff --git a/bsb/connectivity/geometric/morphology_shape_intersection.py b/bsb/connectivity/geometric/morphology_shape_intersection.py
new file mode 100644
index 00000000..97e93146
--- /dev/null
+++ b/bsb/connectivity/geometric/morphology_shape_intersection.py
@@ -0,0 +1,142 @@
+import numpy as np
+
+from ... import config
+from ...config import types
+from .. import ConnectionStrategy
+from .shape_morphology_intersection import _create_geometric_conn_arrays
+from .shape_shape_intersection import ShapeHemitype
+
+
+def overlap_boxes(box1_min, box1_max, box2_min, box2_max):
+ """
+ Check if two axis-aligned minimal bounding boxes overlap.
+
+ :param numpy.ndarray box1_min: 3D point with the lowest coordinates of the first
+ bounding box.
+ :param numpy.ndarray box1_max: 3D point with the highest coordinates of the first
+ bounding box.
+ :param numpy.ndarray box2_min: 3D point with the lowest coordinates of the second
+ bounding box.
+ :param numpy.ndarray box2_max: 3D point with the highest coordinates of the second
+ bounding box.
+ :returns: True if the two boxes overlap.
+ :rtype: bool
+ """
+ return np.all((box1_max >= box2_min) & (box2_max >= box1_min))
+
+
+@config.node
+class MorphologyToShapeIntersection(ConnectionStrategy):
+ """
+ Connect cells by checking which points of the presynaptic cells' morphologies fall
+ inside the geometric shapes representing the postsynaptic cells.
+ """
+
+ postsynaptic = config.attr(type=ShapeHemitype, required=True)
+ affinity = config.attr(type=types.fraction(), required=True, hint=0.1)
+ """Ratio of apositions to keep over the total number of contact points"""
+ pruning_ratio = config.attr(type=types.fraction(), required=True, hint=0.1)
+ """Ratio of conections to keep over the total number of apositions"""
+
+ def get_region_of_interest(self, chunk):
+ lpre, upre = self.presynaptic._get_rect_ext(tuple(chunk.dimensions))
+ post_chunks = self.postsynaptic.get_all_chunks()
+ tree = self.postsynaptic._get_shape_boxtree(
+ post_chunks,
+ )
+ pre_mbb = [
+ np.concatenate(
+ [
+ (lpre + chunk) * chunk.dimensions,
+ np.maximum((upre + chunk), (lpre + chunk) + 1) * chunk.dimensions,
+ ]
+ )
+ ]
+ return [post_chunks[i] for i in tree.query(pre_mbb, unique=True)]
+
+ def connect(self, pre, post):
+ for pre_ps in pre.placement:
+ for post_ps in post.placement:
+ self._connect_type(pre_ps, post_ps)
+
+ def _connect_type(self, pre_ps, post_ps):
+ pre_pos = pre_ps.load_positions()
+ post_pos = post_ps.load_positions()
+
+ post_shapes = self.postsynaptic.shapes_composition.__copy__()
+
+ to_connect_pre = np.empty([0, 3], dtype=int)
+ to_connect_post = np.empty([0, 3], dtype=int)
+
+ morpho_set = pre_ps.load_morphologies()
+ pre_morphos = morpho_set.iter_morphologies(cache=True, hard_cache=True)
+
+ for pre_id, (pre_coord, morpho) in enumerate(zip(pre_pos, pre_morphos)):
+ # Get the branches
+ branches = morpho.get_branches()
+
+ # Build ids array from the morphology
+ pre_points_ids, pre_morpho_coord = _create_geometric_conn_arrays(
+ branches, pre_id, pre_coord
+ )
+ pre_min_mbb, pre_max_mbb = morpho.bounds
+ pre_min_mbb += pre_coord
+ pre_max_mbb += pre_coord
+
+ # Pre-allocate at least one row per postsynaptic cell, since the affinity
+ # subsampling below keeps at least one point per overlapping pair
+ buffer_rows = len(post_pos) * max(1, int(len(pre_morpho_coord) * self.affinity))
+ tmp_pre_selection = np.full([buffer_rows, 3], -1, dtype=int)
+ tmp_post_selection = np.full([buffer_rows, 3], -1, dtype=int)
+ ptr = 0
+
+ for post_id, post_coord in enumerate(post_pos):
+ post_shapes.translate(post_coord)
+ if overlap_boxes(
+ post_shapes.get_mbb_min(),
+ post_shapes.get_mbb_max(),
+ pre_min_mbb,
+ pre_max_mbb,
+ ):
+ mbb_check = post_shapes.inside_mbox(pre_morpho_coord)
+ if np.any(mbb_check):
+ inside_pts = post_shapes.inside_shapes(
+ pre_morpho_coord[mbb_check]
+ )
+ if np.any(inside_pts):
+ local_selection = (pre_points_ids[mbb_check])[inside_pts]
+ if self.affinity < 1 and len(local_selection) > 0:
+ nb_sources = np.max(
+ [
+ 1,
+ int(
+ np.floor(self.affinity * len(local_selection))
+ ),
+ ]
+ )
+ chosen_targets = np.random.choice(
+ local_selection.shape[0], nb_sources
+ )
+ local_selection = local_selection[chosen_targets, :]
+
+ selected_count = len(local_selection)
+ if selected_count > 0:
+ tmp_pre_selection[ptr : ptr + selected_count, 0] = pre_id
+ tmp_post_selection[ptr : ptr + selected_count, 0] = post_id
+ ptr += selected_count
+
+ post_shapes.translate(-post_coord)
+
+ to_connect_pre = np.vstack([to_connect_pre, tmp_pre_selection[:ptr]])
+ to_connect_post = np.vstack([to_connect_post, tmp_post_selection[:ptr]])
+
+ if self.pruning_ratio < 1 and len(to_connect_pre) > 0:
+ ids_to_select = np.random.choice(
+ len(to_connect_pre),
+ int(np.floor(self.pruning_ratio * len(to_connect_pre))),
+ replace=False,
+ )
+ to_connect_pre = to_connect_pre[ids_to_select]
+ to_connect_post = to_connect_post[ids_to_select]
+
+ self.connect_cells(pre_ps, post_ps, to_connect_pre, to_connect_post)
diff --git a/bsb/connectivity/geometric/shape_morphology_intersection.py b/bsb/connectivity/geometric/shape_morphology_intersection.py
new file mode 100644
index 00000000..6fb2b383
--- /dev/null
+++ b/bsb/connectivity/geometric/shape_morphology_intersection.py
@@ -0,0 +1,111 @@
+import numpy as np
+
+from ... import config
+from ...config import types
+from .. import ConnectionStrategy
+from .shape_shape_intersection import ShapeHemitype
+
+
+def _create_geometric_conn_arrays(branches, ids, coord):
+ morpho_points = sum(len(b.points) for b in branches)
+ points_ids = np.empty([morpho_points, 3], dtype=int)
+ morpho_coord = np.empty([morpho_points, 3], dtype=float)
+ local_ptr = 0
+ for i, b in enumerate(branches):
+ points_ids[local_ptr : local_ptr + len(b.points), 0] = ids
+ points_ids[local_ptr : local_ptr + len(b.points), 1] = i
+ points_ids[local_ptr : local_ptr + len(b.points), 2] = np.arange(len(b.points))
+ tmp = b.points + coord
+ morpho_coord[local_ptr : local_ptr + len(b.points), :] = tmp
+ local_ptr += len(b.points)
+ return points_ids, morpho_coord
+
+
+@config.node
+class ShapeToMorphologyIntersection(ConnectionStrategy):
+ """
+ Connect cells by checking which points of the postsynaptic cells' morphologies fall
+ inside the geometric shapes representing the presynaptic cells.
+ """
+
+ presynaptic = config.attr(type=ShapeHemitype, required=True)
+ affinity = config.attr(type=types.fraction(), required=True, hint=0.1)
+ """Ratio of apositions to keep over the total number of contact points"""
+ pruning_ratio = config.attr(type=types.fraction(), required=True, hint=0.1)
+ """Ratio of conections to keep over the total number of apositions"""
+
+ def get_region_of_interest(self, chunk):
+ lpost, upost = self.postsynaptic._get_rect_ext(tuple(chunk.dimensions))
+ pre_chunks = self.presynaptic.get_all_chunks()
+ tree = self.presynaptic._get_shape_boxtree(
+ pre_chunks,
+ )
+ post_mbb = [
+ np.concatenate(
+ [
+ (lpost + chunk) * chunk.dimensions,
+ np.maximum((upost + chunk), (lpost + chunk) + 1) * chunk.dimensions,
+ ]
+ )
+ ]
+
+ return [pre_chunks[i] for i in tree.query(post_mbb, unique=True)]
+
+ def connect(self, pre, post):
+ for pre_ps in pre.placement:
+ for post_ps in post.placement:
+ self._connect_type(pre_ps.cell_type, pre_ps, post_ps.cell_type, post_ps)
+
+ def _connect_type(self, pre_ct, pre_ps, post_ct, post_ps):
+ pre_pos = pre_ps.load_positions()
+ post_pos = post_ps.load_positions()
+
+ pre_shapes = self.presynaptic.shapes_composition.__copy__()
+
+ to_connect_pre = np.empty([0, 3], dtype=int)
+ to_connect_post = np.empty([0, 3], dtype=int)
+
+ morpho_set = post_ps.load_morphologies()
+ post_morphos = morpho_set.iter_morphologies(cache=True, hard_cache=True)
+
+ for post_id, (post_coord, morpho) in enumerate(zip(post_pos, post_morphos)):
+ # Get the branches
+ branches = morpho.get_branches()
+
+ # Build ids array from the morphology
+ post_points_ids, post_morpho_coord = _create_geometric_conn_arrays(
+ branches, post_id, post_coord
+ )
+
+ for pre_id, pre_coord in enumerate(pre_pos):
+ pre_shapes.translate(pre_coord)
+ mbb_check = pre_shapes.inside_mbox(post_morpho_coord)
+
+ if np.any(mbb_check):
+ inside_pts = pre_shapes.inside_shapes(post_morpho_coord[mbb_check])
+ # Find the morpho points inside the postsyn geometric shapes
+ if np.any(inside_pts):
+ local_selection = (post_points_ids[mbb_check])[inside_pts]
+ if self.affinity < 1.0 and len(local_selection) > 0:
+ nb_targets = np.max(
+ [1, int(np.floor(self.affinity * len(local_selection)))]
+ )
+ chosen_targets = np.random.choice(
+ local_selection.shape[0], nb_targets
+ )
+ local_selection = local_selection[chosen_targets, :]
+ selected_count = len(local_selection)
+ if selected_count > 0:
+ to_connect_post = np.vstack(
+ [to_connect_post, local_selection]
+ )
+ pre_tmp = np.full([selected_count, 3], -1, dtype=int)
+ pre_tmp[:, 0] = pre_id
+ to_connect_pre = np.vstack([to_connect_pre, pre_tmp])
+ pre_shapes.translate(-pre_coord)
+ if self.pruning_ratio < 1 and len(to_connect_pre) > 0:
+ ids_to_select = np.random.choice(
+ len(to_connect_pre),
+ int(np.floor(self.pruning_ratio * len(to_connect_pre))),
+ replace=False,
+ )
+ to_connect_pre = to_connect_pre[ids_to_select]
+ to_connect_post = to_connect_post[ids_to_select]
+ self.connect_cells(pre_ps, post_ps, to_connect_pre, to_connect_post)
diff --git a/bsb/connectivity/geometric/shape_shape_intersection.py b/bsb/connectivity/geometric/shape_shape_intersection.py
new file mode 100644
index 00000000..59426284
--- /dev/null
+++ b/bsb/connectivity/geometric/shape_shape_intersection.py
@@ -0,0 +1,159 @@
+import numpy as np
+
+from ... import config
+from ...config import types
+from ...trees import BoxTree
+from .. import ConnectionStrategy
+from ..strategy import Hemitype
+from .geometric_shapes import ShapesComposition
+
+
+@config.node
+class ShapeHemitype(Hemitype):
+ """
+ Class representing a population of cells to connect with a ConnectionStrategy.
+ The cells' morphologies are approximated with a ShapesComposition.
+ """
+
+ shapes_composition = config.attr(type=ShapesComposition, required=True)
+ """
+ Composite shape representing the Hemitype.
+ """
+
+ def get_mbb(self, chunks, chunk_dimension):
+ """
+ Get the list of minimal bounding boxes containing all cells in the `ShapeHemitype`, one box per chunk.
+
+ :param chunks: List of chunks containing the cell types
+ (see bsb.connectivity.strategy.Hemitype.get_all_chunks)
+ :type chunks: List[bsb.storage._chunks.Chunk]
+ :param chunk_dimension: Size of a chunk
+ :type chunk_dimension: float
+ :return: List of bounding boxes in the form [min_x, min_y, min_z, max_x, max_y, max_z]
+ for each chunk containing cells.
+ :rtype: List[numpy.ndarray[float, float, float, float, float, float]]
+ """
+ return [
+ np.concatenate(
+ [
+ self.shapes_composition.get_mbb_min()
+ + np.array(idx_chunk) * chunk_dimension,
+ self.shapes_composition.get_mbb_max()
+ + np.array(idx_chunk) * chunk_dimension,
+ ]
+ )
+ for idx_chunk in chunks
+ ]
+
+ def _get_shape_boxtree(self, chunks):
+ mbbs = self.get_mbb(chunks, chunks[0].dimensions)
+ return BoxTree(mbbs)
+
+
+@config.node
+class ShapeToShapeIntersection(ConnectionStrategy):
+ """
+ Connect cells by sampling points inside the presynaptic cells' geometric shapes and
+ checking which ones fall inside the postsynaptic cells' shapes.
+ """
+
+ presynaptic = config.attr(type=ShapeHemitype, required=True)
+ postsynaptic = config.attr(type=ShapeHemitype, required=True)
+ affinity = config.attr(type=types.fraction(), required=True, hint=0.1)
+ """Ratio of apositions to keep over the total number of contact points"""
+ pruning_ratio = config.attr(type=types.fraction(), required=True, hint=0.1)
+ """Ratio of conections to keep over the total number of apositions"""
+
+ def get_region_of_interest(self, chunk):
+ # Filter postsyn chunks that overlap the presyn chunk.
+ post_chunks = self.postsynaptic.get_all_chunks()
+ tree = self.postsynaptic._get_shape_boxtree(
+ post_chunks,
+ )
+ pre_mbb = self.presynaptic.get_mbb(
+ self.presynaptic.get_all_chunks(), chunk.dimensions
+ )
+ return [post_chunks[i] for i in tree.query(pre_mbb, unique=True)]
+
+ def connect(self, pre, post):
+ for pre_ps in pre.placement:
+ for post_ps in post.placement:
+ self._connect_type(pre_ps.cell_type, pre_ps, post_ps.cell_type, post_ps)
+
+ def _connect_type(self, pre_ct, pre_ps, post_ct, post_ps):
+ pre_pos = pre_ps.load_positions()
+ post_pos = post_ps.load_positions()
+
+ pre_shapes_cache = self.presynaptic.shapes_composition.__copy__()
+ post_shapes_cache = self.postsynaptic.shapes_composition.__copy__()
+
+ to_connect_pre = np.empty([0, 3], dtype=int)
+ to_connect_post = np.empty([0, 3], dtype=int)
+
+ for pre_id, pre_coord in enumerate(pre_pos):
+ # Generate pre point cloud
+ pre_shapes_cache.translate(pre_coord)
+ pre_point_cloud = pre_shapes_cache.generate_point_cloud()
+
+ def boxes_overlap(box1_min, box1_max, box2_min, box2_max):
+ # Axis-aligned boxes overlap iff they overlap on every axis
+ return np.all((box1_max >= box2_min) & (box2_max >= box1_min))
+
+ pre_mbb_min = pre_shapes_cache.get_mbb_min()
+ pre_mbb_max = pre_shapes_cache.get_mbb_max()
+
+ # The +1 accounts for the stochastic rounding performed by sizemod below
+ points_per_cloud = int(len(pre_point_cloud) * self.affinity) + 1
+ tmp_pre_selection = np.full(
+ [len(post_pos) * points_per_cloud, 3], -1, dtype=int
+ )
+ tmp_post_selection = np.full(
+ [len(post_pos) * points_per_cloud, 3], -1, dtype=int
+ )
+ ptr = 0
+ for post_id, post_coord in enumerate(post_pos):
+ post_shapes_cache.translate(post_coord)
+ post_mbb_min = post_shapes_cache.get_mbb_min()
+ post_mbb_max = post_shapes_cache.get_mbb_max()
+ if boxes_overlap(
+ post_mbb_min, post_mbb_max, pre_mbb_min, pre_mbb_max
+ ):
+ # Check the presynaptic point cloud against the postsynaptic mbb
+ inside_mbbox = post_shapes_cache.inside_mbox(pre_point_cloud)
+ if np.any(inside_mbbox):
+ inside_pts = post_shapes_cache.inside_shapes(pre_point_cloud)
+ selected = pre_point_cloud[inside_pts]
+
+ def sizemod(q, aff):
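+ # Stochastic rounding of len(q) * aff: keep the floor, plus one extra
+ # point with probability equal to the fractional part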
+ ln = len(q)
+ return int(
+ np.floor(ln * aff) + (np.random.rand() < ((ln * aff) % 1))
+ )
+
+ selected = selected[
+ np.random.randint(
+ len(selected), size=sizemod(selected, self.affinity)
+ ),
+ :,
+ ]
+ n_synapses = len(selected)
+ if n_synapses > 0:
+ tmp_pre_selection[ptr : ptr + n_synapses, 0] = pre_id
+ tmp_post_selection[ptr : ptr + n_synapses, 0] = post_id
+ ptr += n_synapses
+ post_shapes_cache.translate(-post_coord)
+ if ptr > 0:
+ to_connect_pre = np.vstack([to_connect_pre, tmp_pre_selection[:ptr]])
+ to_connect_post = np.vstack([to_connect_post, tmp_post_selection[:ptr]])
+
+ pre_shapes_cache.translate(-pre_coord)
+
+ if self.pruning_ratio < 1 and len(to_connect_pre) > 0:
+ ids_to_select = np.random.choice(
+ len(to_connect_pre),
+ int(np.floor(self.pruning_ratio * len(to_connect_pre))),
+ replace=False,
+ )
+ to_connect_pre = to_connect_pre[ids_to_select]
+ to_connect_post = to_connect_post[ids_to_select]
+
+ self.connect_cells(pre_ps, post_ps, to_connect_pre, to_connect_post)
diff --git a/bsb/connectivity/strategy.py b/bsb/connectivity/strategy.py
index 17163cff..d4de5ccf 100644
--- a/bsb/connectivity/strategy.py
+++ b/bsb/connectivity/strategy.py
@@ -1,7 +1,10 @@
import abc
import typing
+from functools import cache
from itertools import chain
+import numpy as np
+
from .. import config
from .._util import ichain, obj_str_insert
from ..config import refs, types
@@ -45,6 +48,38 @@ class Hemitype:
take too much disk space or time otherwise.
"""
+ def get_all_chunks(self):
+ """
+ Get the list of all chunks where the cell types were placed
+
+ :return: List of Chunks
+ :rtype: List[bsb.storage._chunks.Chunk]
+ """
+ return [
+ c for ct in self.cell_types for c in ct.get_placement_set().get_all_chunks()
+ ]
+
+ @cache
+ def _get_rect_ext(self, chunk_size):
+ # Returns the lower and upper boundary chunks of the box containing the cell type
+ # population, based on the cell types' morphologies if they exist.
+ # This box is centered on the chunk [0., 0., 0.].
+ # If no morphologies are associated with the cell types, the bounding box size is 0.
+ types = self.cell_types
+ loader = self.morpho_loader
+ ps_list = [ct.get_placement_set() for ct in types]
+ ms_list = [loader(ps) for ps in ps_list]
+ if not sum(map(len, ms_list)):
+ # No cells placed, return smallest possible RoI.
+ return [np.array([0, 0, 0]), np.array([0, 0, 0])]
+ metas = list(chain.from_iterable(ms.iter_meta(unique=True) for ms in ms_list))
+ # TODO: Combine morphology extension information with PS rotation information.
+ # Get the chunk coordinates of the boundaries of this chunk convoluted with the
+ # extension of the intersecting morphologies.
+ lbounds = np.min([m["ldc"] for m in metas], axis=0) // chunk_size
+ ubounds = np.max([m["mdc"] for m in metas], axis=0) // chunk_size
+ return lbounds, ubounds
+
class HemitypeCollection:
def __init__(self, hemitype, roi):
@@ -158,6 +193,15 @@ def connect_cells(self, pre_set, post_set, src_locs, dest_locs, tag=None):
cs.connect(pre_set, post_set, src_locs, dest_locs)
def get_region_of_interest(self, chunk):
+ """
+ Returns the list of chunks containing the potential postsynaptic neurons, based on a
+ chunk containing the presynaptic neurons.
+
+ :param chunk: Presynaptic chunk
+ :type chunk: bsb.storage._chunks.Chunk
+ :returns: List of postsynaptic chunks
+ :rtype: List[bsb.storage._chunks.Chunk]
+ """
pass
def queue(self, pool: "JobPool"):
diff --git a/bsb/core.py b/bsb/core.py
index ac40a4a1..4abd60bb 100644
--- a/bsb/core.py
+++ b/bsb/core.py
@@ -27,7 +27,7 @@
if typing.TYPE_CHECKING:
from .cell_types import CellType
from .config._config import NetworkNode as Network
- from .postprocessing import AfterPlacementHook
+ from .postprocessing import AfterConnectivityHook, AfterPlacementHook
from .simulation.simulation import Simulation
from .storage.interfaces import (
ConnectivitySet,
@@ -87,7 +87,7 @@ def _get_linked_config(storage=None):
path = cfg._meta.get("path", None)
if path and os.path.exists(path):
with open(path, "r") as f:
- cfg = bsb.config.parse_configuration_file(f)
+ cfg = bsb.config.parse_configuration_file(f, path=path)
return cfg
else:
return None
@@ -113,7 +113,7 @@ class Scaffold:
placement: typing.Dict[str, "PlacementStrategy"]
after_placement: typing.Dict[str, "AfterPlacementHook"]
connectivity: typing.Dict[str, "ConnectionStrategy"]
- after_connectivity: typing.Dict[str, "AfterPlacementHook"]
+ after_connectivity: typing.Dict[str, "AfterConnectivityHook"]
simulations: typing.Dict[str, "Simulation"]
def __init__(self, config=None, storage=None, clear=False, comm=None):
diff --git a/bsb/exceptions.py b/bsb/exceptions.py
index f06ffac8..5715a7ae 100644
--- a/bsb/exceptions.py
+++ b/bsb/exceptions.py
@@ -107,7 +107,10 @@
),
DataNotProvidedError=_e(),
PluginError=_e("plugin"),
- ParserError=_e(),
+ ParserError=_e(
+ FileImportError=_e(),
+ FileReferenceError=_e(),
+ ),
ClassError=_e(),
),
)
@@ -183,6 +186,8 @@ class PackageRequirementWarning(ScaffoldWarning):
"EmptySelectionError",
"EmptyVoxelSetError",
"ExternalSourceError",
+ "FileImportError",
+ "FileReferenceError",
"GatewayError",
"IncompleteExternalMapError",
"IncompleteMorphologyError",
diff --git a/bsb/morphologies/__init__.py b/bsb/morphologies/__init__.py
index 4835f56f..b5872e9d 100644
--- a/bsb/morphologies/__init__.py
+++ b/bsb/morphologies/__init__.py
@@ -31,12 +31,6 @@
from ..voxels import VoxelSet
-def parse_morphology_file(file, **kwargs):
- from .parsers import parse_morphology_file
-
- return parse_morphology_file(file, **kwargs)
-
-
class MorphologySet:
"""
Associates a set of :class:`StoredMorphologies
@@ -303,7 +297,7 @@ def branch_iter(branch):
class SubTree:
"""
- Collection of branches, not necesarily all connected.
+ Collection of branches, not necessarily all connected.
"""
def __init__(self, branches, sanitize=True):
@@ -537,8 +531,8 @@ def rotate(self, rotation, center=None):
"""
Point rotation
- :param rot: Scipy rotation
- :type: Union[scipy.spatial.transform.Rotation, List[float,float,float]]
+ :param rotation: Scipy rotation
+ :type rotation: Union[scipy.spatial.transform.Rotation, List[float,float,float]]
:param center: rotation offset point.
:type center: numpy.ndarray
"""
@@ -944,6 +938,26 @@ def as_filtered(self, labels=None):
# Construct and return the morphology
return self.__class__(roots, meta=self.meta.copy())
+ def swap_axes(self, axis1: int, axis2: int):
+ """
+ Interchange two axes of the morphology's points, in place.
+
+ :param int axis1: index of the first axis to exchange
+ :param int axis2: index of the second axis to exchange
+ :return: the modified morphology
+ :rtype: bsb.morphologies.Morphology
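+
+ Example (a sketch; ``m`` is assumed to be a loaded morphology)::
+
+     m.swap_axes(1, 2)  # swap the y and z coordinates in place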
+ """
+ if not 0 <= axis1 < 3 or not 0 <= axis2 < 3:
+ raise ValueError(
+ f"Axes values should be in [0, 1, 2], {axis1}, {axis2} given."
+ )
+ for b in self.branches:
+ old_column = np.copy(b.points[:, axis1])
+ b.points[:, axis1] = b.points[:, axis2]
+ b.points[:, axis2] = old_column
+
+ return self
+
def simplify(self, *args, optimize=True, **kwargs):
super().simplify_branches(*args, **kwargs)
if optimize:
@@ -984,7 +998,7 @@ def to_graph_array(self):
def _copy_api(cls, wrap=lambda self: self):
- # Wraps functions so they are called with `self` wrapped in `wrap`
+ # Wraps functions, so they are called with `self` wrapped in `wrap`
def make_wrapper(f):
@functools.wraps(f)
def wrapper(self, *args, **kwargs):
@@ -1023,7 +1037,7 @@ def __init__(self, points, radii, labels=None, properties=None, children=None):
:param radii: Array of radii associated to each point
:type radii: list | numpy.ndarray
:param labels: Array of labels to associate to each point
- :type labels: EncodedLabels | List[str] | set | numpy.ndarray
+ :type labels: List[str] | set | numpy.ndarray
:param properties: dictionary of per-point data to store in the branch
:type properties: dict
:param children: list of child branches to attach to the branch
@@ -1710,5 +1724,4 @@ def _morpho_to_swc(morpho):
"RotationSet",
"SubTree",
"branch_iter",
- "parse_morphology_file",
]
diff --git a/bsb/services/pool.py b/bsb/services/pool.py
index 70bbc087..c48f41ac 100644
--- a/bsb/services/pool.py
+++ b/bsb/services/pool.py
@@ -412,7 +412,7 @@ def _dep_completed(self, dep):
self._enqueue(self._pool)
def _enqueue(self, pool):
- if not self._deps and self._status is not JobStatus.CANCELLED:
+ if not self._deps and self._status is JobStatus.PENDING:
# Go ahead and submit ourselves to the pool, no dependencies to wait for
# The dispatcher is run on the remote worker and unpacks the data required
# to execute the job contents.
diff --git a/bsb/storage/_files.py b/bsb/storage/_files.py
index 75c8f13a..f82bbac6 100644
--- a/bsb/storage/_files.py
+++ b/bsb/storage/_files.py
@@ -13,6 +13,7 @@
import urllib.parse as _up
import urllib.request as _ur
+import certifi as _cert
import nrrd as _nrrd
import requests as _rq
@@ -229,7 +230,7 @@ def resolve_uri(self, file: FileDependency):
def find(self, file: FileDependency):
with self.create_session() as session:
- response = session.head(self.resolve_uri(file))
+ response = session.head(self.resolve_uri(file), verify=_cert.where())
return response.status_code == 200
def should_update(self, file: FileDependency, stored_file):
@@ -251,12 +252,12 @@ def should_update(self, file: FileDependency, stored_file):
def get_content(self, file: FileDependency):
with self.create_session() as session:
- response = session.get(self.resolve_uri(file))
+ response = session.get(self.resolve_uri(file), verify=_cert.where())
return (response.content, response.encoding)
def get_meta(self, file: FileDependency):
with self.create_session() as session:
- response = session.head(self.resolve_uri(file))
+ response = session.head(self.resolve_uri(file), verify=_cert.where())
return {"headers": dict(response.headers)}
def get_local_path(self, file: FileDependency):
@@ -265,7 +266,9 @@ def get_local_path(self, file: FileDependency):
@_cl.contextmanager
def provide_stream(self, file):
with self.create_session() as session:
- response = session.get(self.resolve_uri(file), stream=True)
+ response = session.get(
+ self.resolve_uri(file), stream=True, verify=_cert.where()
+ )
response.raw.decode_content = True
response.raw.auto_close = False
yield (response.raw, response.encoding)
@@ -294,11 +297,11 @@ def get_nm_meta(self, file: FileDependency):
name = file.uri[idx : (idx + len(name))]
with self.create_session() as session:
try:
- res = session.get(self._nm_url + self._meta + name)
+ res = session.get(self._nm_url + self._meta + name, verify=_cert.where())
except Exception as e:
return {"archive": "none", "neuron_name": "none"}
if res.status_code == 404:
- res = session.get(self._nm_url)
+ res = session.get(self._nm_url, verify=_cert.where())
if res.status_code != 200 or "Service Interruption Notice" in res.text:
warn(f"NeuroMorpho.org is down, can't retrieve morphology '{name}'.")
return {"archive": "none", "neuron_name": "none"}
@@ -400,24 +403,44 @@ def get_stored_file(self):
@config.node
class CodeDependencyNode(FileDependencyNode):
+ """
+ Configuration node allowing external code to be loaded during network loading.
+ """
+
module: str = config.attr(type=str, required=types.shortform())
+ """Should be either the path to a python file or a import like string"""
attr: str = config.attr(type=str)
+ """Attribute to extract from the loaded script"""
@config.property
def file(self):
if getattr(self, "scaffold", None) is not None:
file_store = self.scaffold.files
else:
file_store = None
- return FileDependency(
- self.module.replace(".", _os.sep) + ".py", file_store=file_store
- )
+ if _os.path.isfile(self.module):
+ # Convert a potentially relative path to an absolute path
+ module_file = _os.path.abspath(self.module)
+ else:
+ # Module-like string converted to a path relative to the current folder
+ module_file = "./" + self.module.replace(".", _os.sep) + ".py"
+ return FileDependency(module_file, file_store=file_store)
def __init__(self, module=None, **kwargs):
super().__init__(**kwargs)
if module is not None:
self.module = module
+ def __inv__(self):
+ if not isinstance(self, CodeDependencyNode):
+ return self
+ res = {"module": getattr(self, "module")}
+ if self.attr is not None:
+ res["attr"] = self.attr
+ return res
+
def load_object(self):
import importlib.util
import sys
@@ -429,7 +452,7 @@ def load_object(self):
module = importlib.util.module_from_spec(spec)
sys.modules[self.module] = module
spec.loader.exec_module(module)
- return module if self.attr is None else module[self.attr]
+ return module if self.attr is None else getattr(module, self.attr)
finally:
tmp = list(reversed(sys.path))
tmp.remove(_os.getcwd())
diff --git a/bsb/voxels.py b/bsb/voxels.py
index bc46326f..0d8808b9 100644
--- a/bsb/voxels.py
+++ b/bsb/voxels.py
@@ -637,4 +637,4 @@ def _squash_zero(arr):
return np.where(np.isclose(arr, 0), np.finfo(float).max, arr)
-__all__ = ["BoxTree", "VoxelData", "VoxelSet"]
+__all__ = ["VoxelData", "VoxelSet"]
diff --git a/docs/bsb/bsb.connectivity.geometric.rst b/docs/bsb/bsb.connectivity.geometric.rst
new file mode 100644
index 00000000..b7ba840a
--- /dev/null
+++ b/docs/bsb/bsb.connectivity.geometric.rst
@@ -0,0 +1,37 @@
+bsb.connectivity.geometric package
+==================================
+
+Submodules
+----------
+
+bsb.connectivity.geometric.geometric\_shapes module
+---------------------------------------------------
+
+.. automodule:: bsb.connectivity.geometric.geometric_shapes
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bsb.connectivity.geometric.morphology\_shape\_intersection
+----------------------------------------------------------
+
+.. automodule:: bsb.connectivity.geometric.morphology_shape_intersection
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bsb.connectivity.geometric.shape\_morphology\_intersection
+----------------------------------------------------------
+
+.. automodule:: bsb.connectivity.geometric.shape_morphology_intersection
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+bsb.connectivity.geometric.shape\_shape\_intersection
+-----------------------------------------------------
+
+.. automodule:: bsb.connectivity.geometric.shape_shape_intersection
+ :members:
+ :undoc-members:
+ :show-inheritance:
\ No newline at end of file
diff --git a/docs/bsb/bsb.connectivity.rst b/docs/bsb/bsb.connectivity.rst
index e0062629..5036200e 100644
--- a/docs/bsb/bsb.connectivity.rst
+++ b/docs/bsb/bsb.connectivity.rst
@@ -5,9 +5,10 @@ Subpackages
-----------
.. toctree::
- :maxdepth: 4
+ :maxdepth: 2
bsb.connectivity.detailed
+ bsb.connectivity.geometric
Submodules
----------
diff --git a/docs/conf.py b/docs/conf.py
index 6d5456b8..69f420f8 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -30,6 +30,10 @@
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
+
+autodoc_typehints = "both"
+
+
extensions = [
"sphinx.ext.autodoc",
"sphinx.ext.todo",
@@ -93,6 +97,7 @@
"getting-started/layer.rst",
]
+autoclass_content = "both"
# -- Options for HTML output -------------------------------------------------
diff --git a/docs/connectivity/connectivity-toc.rst b/docs/connectivity/connectivity-toc.rst
index 232589e9..05742ab7 100644
--- a/docs/connectivity/connectivity-toc.rst
+++ b/docs/connectivity/connectivity-toc.rst
@@ -5,3 +5,4 @@
defining
component
connection-strategies
+ geometric
diff --git a/docs/connectivity/geometric.rst b/docs/connectivity/geometric.rst
new file mode 100644
index 00000000..dac2e0fe
--- /dev/null
+++ b/docs/connectivity/geometric.rst
@@ -0,0 +1,221 @@
+.. _geometric:
+
+#######################################
+Geometric shape connectivity strategies
+#######################################
+
+To reconstruct in great detail the connections between two neurons, one needs to provide the
+morphologies of these neurons. However, such data might be lacking or incomplete.
+Moreover, the reconstruction of a detailed connectivity is computationally expensive, as the
+program has to find all appositions between the neurons' arborizations.
+
+To address these two issues, the neurons' morphology is here represented by a collection of
+geometric shapes approximating the pre-/postsynaptic neurites. Neurite appositions can then be
+probabilistically approximated by sampling point clouds from the shapes and pre-filtering with the
+shapes' bounding boxes (see B1 in :ref:`Bibliography`), as sketched below.
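+
+A minimal sketch of this test in plain ``numpy`` (not the actual BSB internals, whose helpers
+differ):
+
+.. code-block:: python
+
+    import numpy as np
+
+    def boxes_overlap(min1, max1, min2, max2):
+        # Axis-aligned boxes overlap iff their intervals overlap on every axis.
+        return np.all((max1 >= min2) & (max2 >= min1))
+
+    def candidate_contacts(cloud, shape_min, shape_max, inside_fn):
+        # Run the (expensive) point-in-shape test only when the boxes overlap.
+        if not boxes_overlap(cloud.min(axis=0), cloud.max(axis=0), shape_min, shape_max):
+            return np.zeros(len(cloud), dtype=bool)
+        return inside_fn(cloud)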
+
+Creating simplified morphologies
+********************************
+
+The :class:`~bsb.connectivity.geometric.geometric_shapes.ShapesComposition` class allows for a
+simplified representation of cell morphologies. It leverages a list of geometric shapes
+(:class:`~bsb.connectivity.geometric.geometric_shapes.GeometricShape`) to represent ``sections``
+of the cell morphology. Similarly to morphologies, labels should be associated with each of these
+``sections``; these labels are used as references during connectivity.
+
+For each ``section`` of the simplified morphology, the class samples a set of 3D points that belong
+to it. This point cloud is used to detect connections between a source and a target neuron.
+The points are uniformly distributed in the ``GeometricShape`` by decomposing it into 3D voxels;
+the program generates as many points as there are voxels in the volume of the shapes.
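+
+For example, the expected number of sampled points can be estimated from the total volume of the
+composition (a sketch; ``sc`` is assumed to be a configured ``ShapesComposition``):
+
+.. code-block:: python
+
+    voxel_size = 25.0
+    total_volume = sum(sc.get_volumes())  # one volume per shape
+    n_points = int(total_volume / voxel_size**3)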
+
+Geometric shapes
+----------------
+
+Pre-defined ``GeometricShape`` implementations can be found in the ``bsb.connectivity.geometric``
+package. Each shape has its own set of parameters. We provide here an example of the configuration
+for a sphere:
+
+.. autoconfig:: bsb.connectivity.geometric.Sphere
+ :no-imports:
+ :max-depth: 1
+
+If needed, users can define their own geometric shape by creating a new class that inherits from
+the abstract base class :class:`~bsb.connectivity.geometric.geometric_shapes.GeometricShape`,
+as sketched below.
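+
+A minimal sketch of such a subclass is given below; the abstract interface (``find_mbb``,
+``get_volume``, ``translate``, ``rotate``, ``generate_point_cloud``, ``check_inside``,
+``wireframe_points``) is inferred from the built-in shapes, so check the base class before
+relying on it:
+
+.. code-block:: python
+
+    import numpy as np
+
+    from bsb import config
+    from bsb.config import types
+    from bsb.connectivity.geometric.geometric_shapes import GeometricShape
+
+
+    @config.node
+    class AxisAlignedBox(GeometricShape, classmap_entry="aa_box"):
+        """A hypothetical axis-aligned box shape."""
+
+        origin = config.attr(type=types.ndarray(dtype=float), required=True)
+        """The corner of the box with the lowest coordinates."""
+        sides = config.attr(type=types.ndarray(dtype=float), required=True)
+        """The lengths of the three sides."""
+
+        def find_mbb(self):
+            # The box is its own minimal bounding box.
+            return self.origin, self.origin + self.sides
+
+        def get_volume(self):
+            return float(np.prod(self.sides))
+
+        def translate(self, t_vector):
+            self.origin += t_vector
+            self.mbb_min += t_vector
+            self.mbb_max += t_vector
+
+        def rotate(self, r_versor, angle):
+            # Kept axis-aligned for simplicity; a real shape would rotate here.
+            pass
+
+        def generate_point_cloud(self, npoints):
+            # Uniform samples inside the box.
+            return self.origin + np.random.rand(npoints, 3) * self.sides
+
+        def check_inside(self, points):
+            rel = points - self.origin
+            return np.all((rel >= 0) & (rel <= self.sides), axis=1)
+
+        def wireframe_points(self, **kwargs):
+            # Omitted for brevity; return x, y, z mesh arrays like the
+            # built-in shapes do.
+            raise NotImplementedError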
+
+ShapesComposition
+-----------------
+To instantiate a :class:`~bsb.connectivity.geometric.geometric_shapes.ShapesComposition`, you need
+to provide a list of ``shapes`` together with their ``labels`` (a list of lists of strings);
+``shapes`` and ``labels`` must have the same length, and multiple labels can be provided for each
+shape. You can additionally control the number of points sampled for connectivity with the
+``voxel_size`` parameter, which corresponds to the side length of the voxels used to decompose the
+shape collection.
+
+.. autoconfig:: bsb.connectivity.geometric.ShapesComposition
+ :no-imports:
+ :max-depth: 2
+
+Here, we represent the cell as a single sphere for the soma, a cone for the dendrites and a cylinder
+for the axon:
+
+.. code-block:: json
+
+ "my_neuron":
+ {
+ "voxel_size": 25,
+ "shapes":
+ [
+ {
+ "type": "sphere",
+ "radius": 40.0,
+ "center": [0., 0., 0.]},
+ {
+ "type": "cone",
+ "center": [0., 0., 0.],
+ "radius": 100.0,
+ "apex": [0., 100., 0.]},
+ {
+ "type": "cylinder",
+ "radius": 100.0,
+ "top_center": [0., 0., 0.],
+ "bottom_center": [0., 0., 10.]
+ }
+ ],
+ "labels":
+ [
+ ["soma"],
+ ["basal_dendrites", "apical_dendrites"],
+ ["axon"]
+ ],
+ }
+
+Geometric shape connectivity
+****************************
+
+The configuration of the geometric shape strategies is similar to that of the other connectivity
+strategies (see :class:`~bsb.connectivity.detailed.voxel_intersection.VoxelIntersection`).
+
+The ``ShapesComposition`` configuration should be provided through the ``shapes_composition`` field
+of the pre- and/or postsynaptic node (depending on the chosen strategy).
+
+The ``morphology_labels`` parameter specifies which shapes of the ``shapes_composition``
+(:class:`~bsb.connectivity.geometric.geometric_shapes.ShapesComposition`) must be used;
+it corresponds to values stored in ``labels``.
+
+The ``affinity`` parameter controls the probability of forming a connection, while the
+``pruning_ratio`` sets the fraction of detected appositions that is kept as connections.
+Schematically, both act as subsampling fractions (a sketch mirroring, not quoting, the strategy
+implementations):
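+
+.. code-block:: python
+
+    import numpy as np
+
+    def subsample(rows, fraction):
+        # Keep roughly `fraction` of the candidate rows (at least one, if any exist).
+        if fraction >= 1 or len(rows) == 0:
+            return rows
+        n_keep = max(1, int(np.floor(fraction * len(rows))))
+        return rows[np.random.choice(len(rows), n_keep, replace=False)]
+
+Three different connectivity strategies based on ``ShapesComposition`` are available.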
+
+MorphologyToShapeIntersection
+-----------------------------
+
+The :class:`~bsb.connectivity.geometric.morphology_shape_intersection.MorphologyToShapeIntersection`
+class creates connections between the points of the presynaptic cell's morphology and a geometric
+shape composition representing the postsynaptic cell, by checking which points of the morphology
+fall inside the geometric shapes.
+This connection strategy is suitable when a detailed morphology is available for the presynaptic
+cell, but not for the postsynaptic cell.
+
+Configuration example:
+
+.. code-block:: json
+
+ "stellate_to_purkinje":
+ {
+ "strategy": "bsb.connectivity.MorphologyToShapeIntersection",
+ "presynaptic": {
+ "cell_types": ["stellate_cell"],
+ "morphology_labels": ["axon"],
+ },
+ "postsynaptic": {
+ "cell_types": ["purkinje_cell"],
+ "morphology_labels": ["sc_targets"],
+ "shape_compositions" : [{
+ "voxel_size": 25,
+ "shapes": [{"type": "sphere", "radius": 40.0, "center": [0., 0., 0.]}],
+ "labels": [["soma", "dendrites", "sc_targets", "axon"]],
+ }]
+ },
+ "affinity": 0.1,
+ "pruning_ratio": 0.5
+ }
+
+ShapeToMorphologyIntersection
+-----------------------------
+
+The :class:`~bsb.connectivity.geometric.shape_morphology_intersection.ShapeToMorphologyIntersection`
+class creates connections between the point cloud representing the presynaptic cell and the points
+of the postsynaptic cell's morphology, by checking which points of the morphology fall inside the
+geometric shapes representing the presynaptic cell.
+This connection strategy is suitable when a detailed morphology is available for the postsynaptic
+cell, but not for the presynaptic cell.
+
+Configuration example:
+
+.. code-block:: json
+
+ "stellate_to_purkinje":
+ {
+ "strategy": "bsb.connectivity.ShapeToMorphologyIntersection",
+ "presynaptic": {
+ "cell_types": ["stellate_cell"],
+ "morphology_labels": ["axon"],
+ "shape_compositions" : [{
+ "voxel_size": 25,
+ "shapes": [{"type": "sphere", "radius": 40.0, "center": [0., 0., 0.]}],
+ "labels": [["soma", "dendrites", "axon"]],
+ }]
+ },
+ "postsynaptic": {
+ "cell_types": ["purkinje_cell"],
+ "morphology_labels": ["sc_targets"]
+ },
+ "affinity": 0.1,
+ "pruning_ratio": 0.5
+ }
+
+ShapeToShapeIntersection
+------------------------
+
+The :class:`~bsb.connectivity.geometric.shape_shape_intersection.ShapeToShapeIntersection`
+class creates connections between the geometric shape compositions representing the presynaptic
+and postsynaptic cells.
+This strategy forms connections by generating points inside the presynaptic probability point
+cloud (one point per voxel) and checking which ones fall inside the geometric shapes representing
+the postsynaptic cell.
+This connection strategy is suitable when no detailed morphology is available for either the
+presynaptic or the postsynaptic cell.
+
+Configuration example:
+
+.. code-block:: json
+
+ "stellate_to_purkinje":
+ {
+ "strategy": "bsb.connectivity.ShapeToShapeIntersection",
+ "presynaptic": {
+ "cell_types": ["stellate_cell"],
+ "morphology_labels": ["axon"],
+ "shape_compositions" : [{
+ "voxel_size": 25,
+ "shapes": [{"type": "sphere", "radius": 40.0, "center": [0., 0., 0.]}],
+ "labels": [["soma", "dendrites", "axon"]],
+ }]
+ },
+ "postsynaptic": {
+ "cell_types": ["purkinje_cell"],
+ "morphology_labels": ["sc_targets"],
+ "shape_compositions" : [{
+ "voxel_size": 25,
+ "shapes": [{"type": "sphere", "radius": 40.0, "center": [0., 0., 0.]}],
+ "labels": [["soma", "dendrites", "sc_targets", "axon"]],
+ }]
+ },
+ "affinity": 0.1,
+ "pruning_ratio": 0.7,
+ }
+
+.. _Bibliography:
+
+Bibliography
+************
+
+* B1: Gandolfi D, Mapelli J, Solinas S, De Schepper R, Geminiani A, Casellato C, D'Angelo E, Migliore M. A realistic morpho-anatomical connection strategy for modelling full-scale point-neuron microcircuits. Sci Rep. 2022 Aug 16;12(1):13864. doi: 10.1038/s41598-022-18024-y. Erratum in: Sci Rep. 2022 Nov 17;12(1):19792. PMID: 35974119; PMCID: PMC9381785.
\ No newline at end of file
diff --git a/docs/getting-started/installation.rst b/docs/getting-started/installation.rst
index a114da3e..8bb38f42 100644
--- a/docs/getting-started/installation.rst
+++ b/docs/getting-started/installation.rst
@@ -11,7 +11,7 @@ The BSB framework can be installed using ``pip``:
.. code-block:: bash
- pip install "bsb~=4.0"
+ pip install "bsb~=4.1"
You can verify that the installation works with:
@@ -47,7 +47,7 @@ To then install the BSB with parallel MPI support:
.. code-block:: bash
- pip install "bsb[parallel]~=4.0"
+ pip install "bsb[parallel]~=4.1"
Simulator backends
==================
diff --git a/pyproject.toml b/pyproject.toml
index ba9620e4..de0666b3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -100,7 +100,7 @@ line-length = 90
profile = "black"
[tool.bumpversion]
-current_version = "4.0.1"
+current_version = "4.1.1"
parse = "(?P\\d+)\\.(?P\\d+)\\.(?P\\d+)"
serialize = ["{major}.{minor}.{patch}"]
search = "{current_version}"
diff --git a/tests/data/code_dependency.py b/tests/data/code_dependency.py
new file mode 100644
index 00000000..1feec292
--- /dev/null
+++ b/tests/data/code_dependency.py
@@ -0,0 +1 @@
+from bsb_test.configs.double_neuron import tree
diff --git a/tests/data/configs/basics.txt b/tests/data/configs/basics.txt
new file mode 100644
index 00000000..bf60cc48
--- /dev/null
+++ b/tests/data/configs/basics.txt
@@ -0,0 +1,5 @@
+{
+ "hello": "world",
+ "list": [1, 2, 3, "waddup"],
+ "nest me hard": {"oh yea": "just like that"},
+}
\ No newline at end of file
diff --git a/tests/data/configs/doubleref.txt b/tests/data/configs/doubleref.txt
new file mode 100644
index 00000000..8bf67684
--- /dev/null
+++ b/tests/data/configs/doubleref.txt
@@ -0,0 +1,8 @@
+{
+ "refs": {
+ "whats the": {
+ "$ref": "basics.txt#/nest me hard",
+ "$ref": "indoc_reference.txt#/target"
+ }
+ }
+}
\ No newline at end of file
diff --git a/tests/data/configs/far/targetme.bla b/tests/data/configs/far/targetme.bla
new file mode 100644
index 00000000..36fc2511
--- /dev/null
+++ b/tests/data/configs/far/targetme.bla
@@ -0,0 +1,25 @@
+<
+ "this": <
+ "key": <
+ "was": "in another folder",
+ "can": <
+ "i": <
+ "be": "imported"
+ >
+ >,
+ "with": <
+ "this": "string"
+ >,
+ "in": <
+ "my": "dict",
+ "or": <
+ "will": "it",
+ "give": "an",
+ "error": <
+ "on": "import"
+ >
+ >
+ >
+ >
+ >
+>
diff --git a/tests/data/configs/indoc_import.txt b/tests/data/configs/indoc_import.txt
new file mode 100644
index 00000000..647ce395
--- /dev/null
+++ b/tests/data/configs/indoc_import.txt
@@ -0,0 +1,20 @@
+{
+ "arr": {
+ "with": {},
+ "many": {},
+ "importable": {
+ "dicts": {
+ "that": "are",
+ "even": {
+ "nested": "eh"
+ }
+ }
+ }
+ },
+ "imp": {
+ "$import": {
+ "ref": "#/arr",
+ "values": ["with", "importable"]
+ }
+ }
+}
\ No newline at end of file
diff --git a/tests/data/configs/indoc_import_merge.txt b/tests/data/configs/indoc_import_merge.txt
new file mode 100644
index 00000000..65395735
--- /dev/null
+++ b/tests/data/configs/indoc_import_merge.txt
@@ -0,0 +1,28 @@
+{
+ "arr": {
+ "with": {},
+ "many": {},
+ "importable": {
+ "dicts": {
+ "that": "are",
+ "even": {
+ "nested": "eh"
+ },
+ "with": ["l", "i", "s", "t", "s"]
+ }
+ }
+ },
+ "imp": {
+ "$import": {
+ "ref": "#/arr",
+ "values": ["with", "importable"]
+ },
+ "importable": {
+ "diff": "added",
+ "dicts": {
+ "that": 4,
+ "with": ["new", "list"]
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/tests/data/configs/indoc_reference.txt b/tests/data/configs/indoc_reference.txt
new file mode 100644
index 00000000..634f15be
--- /dev/null
+++ b/tests/data/configs/indoc_reference.txt
@@ -0,0 +1,22 @@
+{
+ "get": {
+ "a": {
+ "secret": "key",
+ "nested secrets": {
+ "vim": "is hard",
+ "and": "convoluted"
+ }
+ }
+ },
+ "refs": {
+ "whats the": {
+ "$ref": "#/get/a"
+ },
+ "omitted_doc": {
+ "$ref": "/get/a"
+ }
+ },
+ "target": {
+ "for": "another"
+ }
+}
\ No newline at end of file
diff --git a/tests/data/configs/outdoc_import_merge.txt b/tests/data/configs/outdoc_import_merge.txt
new file mode 100644
index 00000000..6ea5abb7
--- /dev/null
+++ b/tests/data/configs/outdoc_import_merge.txt
@@ -0,0 +1,14 @@
+{
+ "imp": {
+ "importable":{
+ "$import": {
+ "ref": "indoc_import.txt#/imp/importable",
+ "values": ["dicts"]
+ },
+ }
+ },
+ "$import": {
+ "ref": "indoc_import_merge.txt#/",
+ "values": ["imp"]
+ }
+}
diff --git a/tests/test_configuration.py b/tests/test_configuration.py
index 6c1c2374..d890678a 100644
--- a/tests/test_configuration.py
+++ b/tests/test_configuration.py
@@ -1,5 +1,6 @@
import inspect
import json
+import os.path
import sys
import unittest
@@ -10,12 +11,14 @@
get_test_config,
list_test_configs,
)
+from bsb_test.configs import get_test_config_module
import bsb
from bsb import (
CastError,
CfgReferenceError,
ClassMapMissingError,
+ CodeDependencyNode,
ConfigurationError,
ConfigurationWarning,
DynamicClassInheritanceError,
@@ -28,6 +31,7 @@
UnfitClassCastError,
UnresolvedClassCastError,
config,
+ from_storage,
)
from bsb._package_spec import get_missing_requirement_reason
from bsb.config import Configuration, _attrs, compose_nodes, types
@@ -1253,6 +1257,37 @@ class TestClass:
with self.assertRaises(RequirementError):
TestClass(a="1", b="6", c="3")
+ def test_code_dependency_node(self):
+ @config.node
+ class Test:
+ c = config.attr(type=CodeDependencyNode)
+
+ module = get_test_config_module("double_neuron")
+ script = str(module.__file__)
+ # Test with a module like string
+ import_str = os.path.relpath(
+ os.path.join(os.path.dirname(__file__), "data/code_dependency")
+ ).replace(os.sep, ".")
+ b = Test(
+ c=import_str,
+ _parent=TestRoot(),
+ )
+ self.assertEqual(b.c.load_object().tree, module.tree)
+ # test with a file
+ b = Test(
+ c={"module": script},
+ _parent=TestRoot(),
+ )
+ # Test variable tree inside the file.
+ self.assertEqual(b.c.load_object().tree, module.tree)
+ self.assertEqual(b.__tree__(), {"c": {"module": script}})
+ # Test with relative path
+ b = Test(
+ c={"module": os.path.relpath(script), "attr": "tree"},
+ _parent=TestRoot(),
+ )
+ self.assertEqual(b.c.load_object(), module.tree)
+
@config.dynamic(
type=types.in_classmap(),
@@ -1649,7 +1684,7 @@ def test_composite_node(self):
assert type(self.tested.attrC == config.ConfigurationAttribute)
-class TestPackageRequirements(unittest.TestCase):
+class TestPackageRequirements(RandomStorageFixture, unittest.TestCase, engine_name="fs"):
def test_basic_version(self):
self.assertIsNone(get_missing_requirement_reason("bsb-core==" + bsb.__version__))
@@ -1669,3 +1704,15 @@ def test_uninstalled_package(self):
self.assertIsNotNone(get_missing_requirement_reason("bsb-core-soup==4.0"))
with self.assertWarns(PackageRequirementWarning):
Configuration.default(packages=["bsb-core-soup==4.0"])
+
+ def test_installed_package(self):
+ self.assertIsNone(get_missing_requirement_reason(f"bsb-core~={bsb.__version__}"))
+ # Should produce no warnings
+ cfg = Configuration.default(packages=[f"bsb-core~={bsb.__version__}"])
+ # Checking that the config with package requirements can be saved in storage
+ self.network = Scaffold(cfg, self.storage)
+ # Checking that the config can be reloaded from the storage
+ network2 = from_storage(self.storage.root)
+ self.assertEqual(
+ self.network.configuration.packages, network2.configuration.packages
+ )
diff --git a/tests/test_connectivity.py b/tests/test_connectivity.py
index c864e3a9..0705dbb7 100644
--- a/tests/test_connectivity.py
+++ b/tests/test_connectivity.py
@@ -457,7 +457,6 @@ def connect_spy(strat, pre, post):
)
-@unittest.skip("https://github.com/dbbs-lab/bsb-core/issues/820")
class TestVoxelIntersection(
RandomStorageFixture,
NetworkFixture,
@@ -680,7 +679,7 @@ def test_multi_indegree(self):
self.network.compile()
for post_name in ("inhibitory", "extra"):
post_ps = self.network.get_placement_set(post_name)
- total = np.zeros(len(post_ps))
+ total = np.zeros(len(post_ps), dtype=int)
for pre_name in ("excitatory", "extra"):
cs = self.network.get_connectivity_set(
f"multi_indegree_{pre_name}_to_{post_name}"
@@ -688,7 +687,7 @@ def test_multi_indegree(self):
_, post_locs = cs.load_connections().all()
ps = self.network.get_placement_set("inhibitory")
u, c = np.unique(post_locs[:, 0], return_counts=True)
- this = np.zeros(len(post_ps))
+ this = np.zeros(len(post_ps), dtype=int)
this[u] = c
total += this
self.assertTrue(np.all(total == 50), "Not all cells have indegree 50")
diff --git a/tests/test_geometric_connectivity.py b/tests/test_geometric_connectivity.py
new file mode 100644
index 00000000..0e2d7004
--- /dev/null
+++ b/tests/test_geometric_connectivity.py
@@ -0,0 +1,214 @@
+import unittest
+
+from bsb_test import (
+ FixedPosConfigFixture,
+ MorphologiesFixture,
+ NetworkFixture,
+ NumpyTestCase,
+ RandomStorageFixture,
+)
+
+from bsb import Configuration, DatasetNotFoundError, Scaffold
+from bsb.connectivity import (
+ MorphologyToShapeIntersection,
+ ShapeToMorphologyIntersection,
+ ShapeToShapeIntersection,
+)
+
+
+class TestShapeConnectivity(
+ RandomStorageFixture,
+ FixedPosConfigFixture,
+ NetworkFixture,
+ MorphologiesFixture,
+ NumpyTestCase,
+ unittest.TestCase,
+ engine_name="hdf5",
+ morpho_filters=["2branch"],
+):
+ def setUp(self):
+ super().setUp()
+
+ self.cfg = Configuration.default(
+ cell_types=dict(
+ test_cell_morpho=dict(
+ spatial=dict(
+ radius=1, density=1, morphologies=[dict(names=["2branch"])]
+ )
+ ),
+ test_cell_pc_1=dict(spatial=dict(radius=1, density=1)),
+ test_cell_pc_2=dict(spatial=dict(radius=1, density=1)),
+ ),
+ placement=dict(
+ fixed_pos_morpho=dict(
+ strategy="bsb.placement.FixedPositions",
+ cell_types=["test_cell_morpho"],
+ partitions=[],
+ positions=[[0, 0, 0], [0, 0, 100], [50, 0, 0], [0, -100, 0]],
+ ),
+ fixed_pos_pc_1=dict(
+ strategy="bsb.placement.FixedPositions",
+ cell_types=["test_cell_pc_1"],
+ partitions=[],
+ positions=[[40, 40, 40]],
+ ),
+ fixed_pos_pc_2=dict(
+ strategy="bsb.placement.FixedPositions",
+ cell_types=["test_cell_pc_2"],
+ partitions=[],
+ positions=[[0, -100, 0]],
+ ),
+ ),
+ )
+
+ self.network = Scaffold(self.cfg, self.storage)
+ self.network.compile(skip_connectivity=True)
+
+ def test_shape_to_shape(self):
+ voxel_size = 25
+ config_sphere = dict(type="sphere", radius=40.0, origin=[0, 0, 0])
+ ball_shape = {
+ "voxel_size": voxel_size,
+ "shapes": [config_sphere],
+ "labels": [["sphere"]],
+ }
+ # All the points of the presyn shape are inside the postsyn shape
+ self.network.connectivity["shape_to_shape_1"] = ShapeToShapeIntersection(
+ presynaptic=dict(
+ cell_types=["test_cell_pc_1"],
+ shapes_composition=ball_shape,
+ morphology_labels=["soma"],
+ ),
+ postsynaptic=dict(
+ cell_types=["test_cell_pc_1"],
+ shapes_composition=ball_shape,
+ morphology_labels=["soma"],
+ ),
+ affinity=0.9,
+ pruning_ratio=0.1,
+ )
+
+ # There are no intersections between the presyn and postsyn shapes
+ self.network.connectivity["shape_to_shape_2"] = ShapeToShapeIntersection(
+ presynaptic=dict(
+ cell_types=["test_cell_pc_1"],
+ shapes_composition=ball_shape,
+ morphology_labels=["soma"],
+ ),
+ postsynaptic=dict(
+ cell_types=["test_cell_pc_2"],
+ shapes_composition=ball_shape,
+ morphology_labels=["soma"],
+ ),
+ affinity=0.9,
+ pruning_ratio=0.1,
+ )
+
+ self.network.compile(skip_placement=True, append=True)
+
+ cs = self.network.get_connectivity_set("shape_to_shape_1")
+ con = cs.load_connections().all()[0]
+ intersection_points = len(con)
+ self.assertGreater(
+ intersection_points,
+ 0,
+ "expected at least one intersection point",
+ )
+
+ with self.assertRaises(DatasetNotFoundError):
+ # No connectivity set expected because no overlap of the populations' chunks.
+ self.network.get_connectivity_set("shape_to_shape_2")
+
+ def test_shape_to_morpho(self):
+ voxel_size = 25
+ config_sphere = dict(type="sphere", radius=40.0, origin=[0, 0, 0])
+ ball_shape = {
+ "voxel_size": voxel_size,
+ "shapes": [config_sphere],
+ "labels": [["sphere"]],
+ }
+
+ # We know a priori that there are intersections between the presyn shape and the morphology
+ self.network.connectivity["shape_to_morpho_1"] = ShapeToMorphologyIntersection(
+ presynaptic=dict(
+ cell_types=["test_cell_pc_2"],
+ shapes_composition=ball_shape,
+ morphology_labels=["soma"],
+ ),
+ postsynaptic=dict(cell_types=["test_cell_morpho"]),
+ affinity=0.5,
+ pruning_ratio=0.5,
+ )
+
+        # There are no intersections between the presyn shape and the morphology
+ self.network.connectivity["shape_to_morpho_2"] = ShapeToMorphologyIntersection(
+ presynaptic=dict(
+ cell_types=["test_cell_pc_1"],
+ shapes_composition=ball_shape,
+ morphology_labels=["soma"],
+ ),
+ postsynaptic=dict(cell_types=["test_cell_morpho"]),
+ affinity=0.5,
+ pruning_ratio=0.5,
+ )
+
+ self.network.compile(skip_placement=True, append=True)
+
+ cs = self.network.get_connectivity_set("shape_to_morpho_1")
+ con = cs.load_connections().all()[0]
+ intersection_points = len(con)
+ self.assertGreater(
+ intersection_points, 0, "expected at least one intersection point"
+ )
+
+ cs = self.network.get_connectivity_set("shape_to_morpho_2")
+ con = cs.load_connections().all()[0]
+ intersection_points = len(con)
+ self.assertClose(0, intersection_points, "expected no intersection points")
+
+ def test_morpho_to_shape(self):
+ voxel_size = 25
+ config_sphere = dict(type="sphere", radius=40.0, origin=[0, 0, 0])
+ ball_shape = {
+ "voxel_size": voxel_size,
+ "shapes": [config_sphere],
+ "labels": [["sphere"]],
+ }
+
+        # We know a priori that there are intersections between the morphology and the postsyn shape
+        self.network.connectivity["morpho_to_shape_1"] = MorphologyToShapeIntersection(
+ postsynaptic=dict(
+ cell_types=["test_cell_pc_2"],
+ shapes_composition=ball_shape,
+ morphology_labels=["soma"],
+ ),
+ presynaptic=dict(cell_types=["test_cell_morpho"]),
+ affinity=0.5,
+ pruning_ratio=0.5,
+ )
+
+        # There are no intersections between the morphology and the postsyn shape.
+        self.network.connectivity["morpho_to_shape_2"] = MorphologyToShapeIntersection(
+ postsynaptic=dict(
+ cell_types=["test_cell_pc_1"],
+ shapes_composition=ball_shape,
+ morphology_labels=["soma"],
+ ),
+ presynaptic=dict(cell_types=["test_cell_morpho"]),
+ affinity=0.5,
+ pruning_ratio=0.5,
+ )
+
+ self.network.compile(skip_placement=True, append=True)
+
+ cs = self.network.get_connectivity_set("shape_to_morpho_1")
+ con = cs.load_connections().all()[0]
+ intersection_points = len(con)
+ self.assertGreater(
+ intersection_points, 0, "expected at least one intersection point"
+ )
+
+ cs = self.network.get_connectivity_set("shape_to_morpho_2")
+ con = cs.load_connections().all()[0]
+ intersection_points = len(con)
+ self.assertClose(0, intersection_points, "expected no intersection points")
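The three strategies exercised above share a single configuration pattern: a `shapes_composition` dict per hemitype plus the `affinity` and `pruning_ratio` knobs. Below is a minimal sketch of that pattern outside any fixture; the cell type names are placeholders and would need to exist in the surrounding configuration.

```python
# Sketch of the shape-based connectivity pattern used in the tests above.
# "pre_cells" and "post_cells" are placeholder cell type names.
from bsb.connectivity import ShapeToShapeIntersection

probe = {
    "voxel_size": 25,  # sampling resolution of the generated point clouds
    "shapes": [dict(type="sphere", radius=40.0, origin=[0, 0, 0])],
    "labels": [["sphere"]],  # one label list per shape
}

strategy = ShapeToShapeIntersection(
    presynaptic=dict(
        cell_types=["pre_cells"],
        shapes_composition=probe,
        morphology_labels=["soma"],
    ),
    postsynaptic=dict(
        cell_types=["post_cells"],
        shapes_composition=probe,
        morphology_labels=["soma"],
    ),
    affinity=0.9,
    pruning_ratio=0.1,
)
```

As `test_shape_to_shape` shows, a strategy only yields a connectivity set when the two populations' chunks overlap; otherwise `get_connectivity_set` raises `DatasetNotFoundError`.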
diff --git a/tests/test_geometric_shapes.py b/tests/test_geometric_shapes.py
new file mode 100644
index 00000000..e6b5aed3
--- /dev/null
+++ b/tests/test_geometric_shapes.py
@@ -0,0 +1,707 @@
+import unittest
+
+import numpy as np
+from bsb_test import NumpyTestCase
+
+from bsb import RequirementError
+from bsb.connectivity import (
+ Cone,
+ Cuboid,
+ Cylinder,
+ Ellipsoid,
+ Parallelepiped,
+ ShapesComposition,
+ Sphere,
+)
+
+
+class TestGeometricShapes(unittest.TestCase, NumpyTestCase):
+ def _check_points_inside(self, sc, volume, voxel_size):
+ self.assertClose(np.sum(sc.get_volumes()), volume)
+ expected_number_of_points = int(volume / voxel_size**3)
+ point_cloud = sc.generate_point_cloud()
+ npoints = len(point_cloud)
+
+ # Check the number of points in the point cloud
+ self.assertEqual(
+ npoints,
+ expected_number_of_points,
+ "The number of point in the point cloud is not the expected one",
+ )
+
+        # Check that the point cloud is inside the shape composition
+        points_inside = sc.inside_shapes(point_cloud)
+        all_points_inside = np.all(points_inside)
+        self.assertEqual(
+            all_points_inside,
+            True,
+            "The point cloud should be inside the ShapesComposition",
+        )
+
+ def _check_translation(self, sc, expected_mbb):
+ # Check translation
+ translation_vec = np.array([1.0, 10.0, 100.0])
+ sc.translate(translation_vec)
+ mbb = sc.find_mbb()
+ expected_mbb += translation_vec
+ self.assertClose(
+ mbb[0],
+ expected_mbb[0],
+ "The minimal bounding box returned by find_mbb method is not the expected one",
+ )
+ self.assertClose(
+ mbb[1],
+ expected_mbb[1],
+ "The minimal bounding box returned by find_mbb method is not the expected one",
+ )
+ sc.translate(-translation_vec)
+ expected_mbb -= translation_vec
+
+    # Create a sphere, add it to a ShapesComposition object and test the minimal bounding box,
+    # inside_mbox, inside_shapes and generate_point_cloud methods
+    def test_sphere(self):
+        # Create a ShapesComposition object; the voxel size only affects the point cloud density.
+ conf = dict(voxel_size=25, shapes=[], labels=[])
+ sc = ShapesComposition(conf)
+ self.assertEqual(None, sc.generate_point_cloud())
+ self.assertEqual(None, sc.inside_shapes(np.array([[0.0, 0.0, 0.0]])))
+ self.assertEqual(None, sc.generate_wireframe())
+ self.assertEqual([], sc.get_volumes())
+
+ # Add the sphere to the ShapesComposition object
+ radius = 100.0
+ origin = np.array([0, 0, 0], dtype=np.float64)
+ configuration = dict(radius=radius, origin=origin)
+ sc.add_shape(Sphere(configuration), ["sphere"])
+
+        # Find the mbb
+ mbb = sc.find_mbb()
+ expected_mbb = np.array(
+ [[-100.0, -100.0, -100.0], [100.0, 100.0, 100.0]], dtype=np.float64
+ )
+
+        # If the result is correct, the mbb is the box spanned by
+        # the opposite vertices [-100., -100., -100.] and [100., 100., 100.].
+ # The tuple must be in the correct order.
+ self.assertClose(
+ mbb[0],
+ expected_mbb[0],
+ "The minimal bounding box returned by find_mbb method is not the expected one",
+ )
+ self.assertClose(
+ mbb[1],
+ expected_mbb[1],
+ "The minimal bounding box returned by find_mbb method is not the expected one",
+ )
+
+ # The point of coordinates (0,0,50) is inside the sphere, while (200,200,200) is not.
+ # We test check_mbox, check_inside with these two points
+ point_to_check = np.array([[0, 0, 50], [200, 200, 200]], dtype=np.float64)
+ inside_mbox = sc.inside_mbox(point_to_check)
+ inside_shape = sc.inside_shapes(point_to_check)
+
+ # Check if the points are inside the mbb
+ expected_inside_mbox = [True, False]
+ self.assertEqual(
+ inside_mbox[0],
+ expected_inside_mbox[0],
+ "The point (0,0,50) should be inside the minimal bounding box",
+ )
+ self.assertEqual(
+ inside_mbox[1],
+ expected_inside_mbox[1],
+ "The point (200,200,200) should be outside the minimal bounding box",
+ )
+
+ # Check if the points are inside the sphere
+ expected_inside_shape = [True, False]
+ self.assertEqual(
+ inside_shape[0],
+ expected_inside_shape[0],
+ "The point (0,0,50) should be inside the sphere",
+ )
+ self.assertEqual(
+ inside_shape[1],
+ expected_inside_shape[1],
+ "The point (0,0,-50) should be outside the sphere",
+ )
+
+ # Test generate_point_cloud method.
+        # The expected number of points is given by the volume of the sphere divided by
+        # the voxel side cubed.
+ # The points should be inside the sphere.
+ volume = 4 * (np.pi * configuration["radius"] ** 3) / 3.0
+ self._check_points_inside(sc, volume, conf["voxel_size"])
+ self._check_translation(sc, expected_mbb)
+ wireframe = np.array(sc.generate_wireframe())
+ self.assertEqual((3, 1, 30, 30), wireframe.shape)
+ self.assertTrue(
+ np.allclose(np.linalg.norm(wireframe[:, 0, :, 0].T - origin, axis=1), radius)
+ )
+
+    # Create an ellipsoid, add it to a ShapesComposition object and test the minimal bounding box,
+    # inside_mbox, inside_shapes and generate_point_cloud methods
+    def test_ellipsoid(self):
+        # Create a ShapesComposition object; the voxel size only affects the point cloud density.
+ conf = dict(voxel_size=25, shapes=[], labels=[])
+ sc = ShapesComposition(conf)
+
+ # Add the ellipsoid to the ShapesComposition object
+ configuration = dict(
+ origin=np.array([0, 0, 0], dtype=np.float64),
+ lambdas=np.array([50, 100, 10], dtype=np.float64),
+ v0=np.array([1, 0, 0], dtype=np.float64),
+ v1=np.array([0, 1, 0], dtype=np.float64),
+ v2=np.array([0, 0, 1], dtype=np.float64),
+ )
+ ellipsoid = Ellipsoid(configuration)
+ sc.add_shape(ellipsoid, ["ellipsoid"])
+
+        # Find the mbb
+ mbb = sc.find_mbb()
+ expected_mbb = np.array(
+ [[-50.0, -100.0, -10.0], [50.0, 100.0, 10.0]], dtype=np.float64
+ )
+
+        # If the result is correct, the mbb is the box spanned by
+        # the opposite vertices [-50., -100., -10.] and [50., 100., 10.].
+ # The tuple must be in the correct order.
+ self.assertClose(
+ mbb[0],
+ expected_mbb[0],
+ "The minimal bounding box returned by find_mbb method is not the expected one",
+ )
+ self.assertClose(
+ mbb[1],
+ expected_mbb[1],
+ "The minimal bounding box returned by find_mbb method is not the expected one",
+ )
+
+ # The point of coordinates (0,0,5) is inside the ellipsoid, while (20,20,20) is not.
+ # We test check_mbox, check_inside with these two points
+ point_to_check = np.array([[0, 0, 5], [20, 20, 20]], dtype=np.float64)
+ inside_mbox = sc.inside_mbox(point_to_check)
+ inside_shape = sc.inside_shapes(point_to_check)
+
+ # Check if the points are inside the mbb
+ expected_inside_mbox = [True, False]
+ self.assertEqual(
+ inside_mbox[0],
+ expected_inside_mbox[0],
+ "The point (0,0,50) should be inside the minimal bounding box",
+ )
+ self.assertEqual(
+ inside_mbox[1],
+ expected_inside_mbox[1],
+ "The point (200,200,200) should be outside the minimal bounding box",
+ )
+
+ # Check if the points are inside the ellipsoid
+ expected_inside_shape = [True, False]
+ self.assertEqual(
+ inside_shape[0],
+ expected_inside_shape[0],
+ "The point (0,0,50) should be inside the ellipsoid",
+ )
+ self.assertEqual(
+ inside_shape[1],
+ expected_inside_shape[1],
+ "The point (0,0,-50) should be outside the ellipsoid",
+ )
+
+ # Test generate_point_cloud method.
+        # The expected number of points is given by the volume of the ellipsoid divided by the voxel side cubed.
+ # The points should be inside the ellipsoid.
+ volume = (
+ np.pi
+ * configuration["lambdas"][0]
+ * configuration["lambdas"][1]
+ * configuration["lambdas"][2]
+ )
+ self._check_points_inside(sc, volume, conf["voxel_size"])
+ self._check_translation(sc, expected_mbb)
+
+ # Check rotation
+ ellipsoid.rotate(np.array([0.0, 0.0, 1.0]), np.pi / 2)
+ mbb = ellipsoid.find_mbb()
+ self.assertClose(mbb[0], expected_mbb[0, [1, 0, 2]])
+ self.assertClose(mbb[1], expected_mbb[1, [1, 0, 2]])
+
+ wireframe = np.array(sc.generate_wireframe())
+ self.assertEqual((3, 1, 30, 30), wireframe.shape)
+ for coord in (wireframe[:, 0, :, 0] - ellipsoid.origin[..., np.newaxis]).T:
+ self.assertTrue(-1e-5 <= coord[0] <= 1e5)
+ self.assertTrue(-1e-5 <= coord[1] <= 50)
+ self.assertTrue(-10 <= coord[2] <= 10)
+
+    # Create a cylinder, add it to a ShapesComposition object and test the minimal bounding box,
+    # inside_mbox, inside_shapes and generate_point_cloud methods
+    def test_cylinder(self):
+        # Create a ShapesComposition object; the voxel size only affects the point cloud density.
+ conf = dict(voxel_size=25, shapes=[], labels=[])
+ sc = ShapesComposition(conf)
+
+ # Add the cylinder to the ShapesComposition object
+ radius = 100.0
+ origin = np.array([0, 0, 0], dtype=np.float64)
+ top_center = np.array([0, 0, 10], dtype=np.float64)
+
+ configuration = dict(
+ radius=radius,
+ origin=origin,
+ top_center=top_center,
+ )
+ cylinder = Cylinder(configuration)
+ sc.add_shape(cylinder, ["cylinder"])
+
+        # Find the mbb
+ mbb = sc.find_mbb()
+ expected_mbb = np.array(
+ [[-100.0, -100.0, 0.0], [100.0, 100.0, 10.0]], dtype=np.float64
+ )
+
+        # If the result is correct, the mbb is the box spanned by
+        # the opposite vertices [-100., -100., 0.] and [100., 100., 10.].
+ # The tuple must be in the correct order.
+ self.assertClose(
+ mbb[0],
+ expected_mbb[0],
+ "The minimal bounding box returned by find_mbb method is not the expected one",
+ )
+ self.assertClose(
+ mbb[1],
+ expected_mbb[1],
+ "The minimal bounding box returned by find_mbb method is not the expected one",
+ )
+
+ # The point of coordinates (0,0,5) is inside the cylinder, while (200,200,200) is not.
+ # We test check_mbox, check_inside with these two points
+ point_to_check = np.array([[0, 0, 5], [200, 200, 200]], dtype=np.float64)
+ inside_mbox = sc.inside_mbox(point_to_check)
+ inside_shape = sc.inside_shapes(point_to_check)
+
+ # Check if the points are inside the mbb
+ expected_inside_mbox = [True, False]
+ self.assertEqual(
+ inside_mbox[0],
+ expected_inside_mbox[0],
+ "The point (0,0,50) should be inside the minimal bounding box",
+ )
+ self.assertEqual(
+ inside_mbox[1],
+ expected_inside_mbox[1],
+ "The point (200,200,200) should be outside the minimal bounding box",
+ )
+
+ # Check if the points are inside the cylinder
+ expected_inside_shape = [True, False]
+ self.assertEqual(
+ inside_shape[0],
+ expected_inside_shape[0],
+ "The point (0,0,50) should be inside the cylinder",
+ )
+ self.assertEqual(
+ inside_shape[1],
+ expected_inside_shape[1],
+ "The point (0,0,-50) should be outside the cylinder",
+ )
+
+ # Test generate_point_cloud method.
+        # The expected number of points is given by the volume of the cylinder divided by the voxel side cubed.
+ # The points should be inside the cylinder.
+ height = np.linalg.norm(configuration["top_center"] - configuration["origin"])
+ volume = np.pi * height * configuration["radius"] ** 2
+ self._check_points_inside(sc, volume, conf["voxel_size"])
+ self._check_translation(sc, expected_mbb)
+
+ # Check rotation
+ cylinder.rotate(np.array([1.0, 0.0, 0.0]), np.pi / 2)
+ mbb = cylinder.find_mbb()
+ self.assertClose(mbb[0], [-100.0, -10.0, -100.0])
+ self.assertClose(mbb[1], [100.0, 0.0, 100.0])
+
+ wireframe = np.array(sc.generate_wireframe())
+ self.assertEqual((3, 1, 30, 30), wireframe.shape)
+ self.assertTrue(
+ np.allclose(
+ np.linalg.norm(
+ (wireframe[:, 0, :, 0] - cylinder.origin[..., np.newaxis])[
+ np.array([0, 2])
+ ],
+ axis=0,
+ ),
+ radius,
+ )
+ )
+ for p in wireframe[:, 0, :, 0].T:
+ self.assertTrue(
+ 0
+ <= np.absolute(p[1] - cylinder.origin[1])
+ <= np.absolute(cylinder.top_center[1] - cylinder.origin[1])
+ )
+
+    # Create a parallelepiped, add it to a ShapesComposition object and test the minimal bounding
+    # box, inside_mbox, inside_shapes and generate_point_cloud methods
+    def test_parallelepiped(self):
+        # Create a ShapesComposition object; the voxel size only affects the point cloud density.
+ conf = dict(voxel_size=5, shapes=[], labels=[])
+ sc = ShapesComposition(conf)
+
+ # Add the parallelepiped to the ShapesComposition object
+ configuration = dict(
+ origin=np.array([-5, -5, -5], dtype=np.float64),
+ side_vector_1=np.array([10, 0, 0], dtype=np.float64),
+ side_vector_2=np.array([0, 100, 0], dtype=np.float64),
+ side_vector_3=np.array([0, 0, 10], dtype=np.float64),
+ )
+ parallelepiped = Parallelepiped(configuration)
+ sc.add_shape(
+ parallelepiped,
+ ["parallelepiped"],
+ )
+
+        # Find the mbb
+ mbb = sc.find_mbb()
+ expected_mbb = np.array([[-5.0, -5.0, -5.0], [5.0, 95.0, 5.0]], dtype=np.float64)
+
+        # If the result is correct, the mbb is the box spanned by
+        # the opposite vertices [-5., -5., -5.] and [5., 95., 5.].
+ # The tuple must be in the correct order.
+ self.assertClose(
+ mbb[0],
+ expected_mbb[0],
+ "The minimal bounding box returned by find_mbb method is not the expected one",
+ )
+ self.assertClose(
+ mbb[1],
+ expected_mbb[1],
+ "The minimal bounding box returned by find_mbb method is not the expected one",
+ )
+
+        # The point of coordinates (0,0,0) is inside the parallelepiped, while (10,100,10) is not.
+ # We test check_mbox, check_inside with these two points
+ point_to_check = np.array([[0, 0, 0], [10, 100, 10]], dtype=np.float64)
+ inside_mbox = sc.inside_mbox(point_to_check)
+ inside_shape = sc.inside_shapes(point_to_check)
+
+ # Check if the points are inside the mbb
+ expected_inside_mbox = [True, False]
+ self.assertEqual(
+ inside_mbox[0],
+ expected_inside_mbox[0],
+ "The point (0,0,0) should be inside the minimal bounding box",
+ )
+ self.assertEqual(
+ inside_mbox[1],
+ expected_inside_mbox[1],
+ "The point (10,10,10) should be outside the minimal bounding box",
+ )
+
+ # Check if the points are inside the parallelepiped
+ expected_inside_shape = [True, False]
+ self.assertEqual(
+ inside_shape[0],
+ expected_inside_shape[0],
+ "The point (0,0,0) should be inside the parallelepiped",
+ )
+ self.assertEqual(
+ inside_shape[1],
+ expected_inside_shape[1],
+ "The point (10,100,10) should be outside the parallelepiped",
+ )
+
+ # Test generate_point_cloud method.
+ # The expected number of points is given by the volume of the parallelepiped divided by the
+        # voxel side cubed.
+ # The points should be inside the parallelepiped.
+ volume = (
+ np.linalg.norm(configuration["side_vector_1"])
+ * np.linalg.norm(configuration["side_vector_2"])
+ * np.linalg.norm(configuration["side_vector_3"])
+ )
+ self._check_points_inside(sc, volume, conf["voxel_size"])
+ self._check_translation(sc, expected_mbb)
+
+ # Check rotation
+ parallelepiped.rotate(np.array([1.0, 0.0, 0.0]), np.pi / 2)
+ mbb = parallelepiped.find_mbb()
+ expected_mbb = np.array(
+ [[-5.0, -15.0, -5.0], [5.0, -5.0, 95.0]], dtype=np.float64
+ )
+ self.assertClose(mbb[0], expected_mbb[0])
+ self.assertClose(mbb[1], expected_mbb[1])
+ wireframe = np.array(sc.generate_wireframe())
+ self.assertEqual((3, 1, 4, 4), wireframe.shape)
+ self.assertTrue(np.alltrue(wireframe.reshape(3, 16).T - expected_mbb[0] >= -1e-5))
+ self.assertTrue(np.alltrue(wireframe.reshape(3, 16).T - expected_mbb[1] <= 1e-5))
+
+    # Create a cuboid, add it to a ShapesComposition object and test the minimal bounding box,
+    # inside_mbox, inside_shapes and generate_point_cloud methods
+    def test_cuboid(self):
+        # Create a ShapesComposition object; the voxel size only affects the point cloud density.
+ conf = dict(voxel_size=25, shapes=[], labels=[])
+ sc = ShapesComposition(conf)
+
+ # Add the cuboid to the ShapesComposition object
+ configuration = dict(
+ origin=np.array([0, 0, 0], dtype=np.float64),
+ side_length_1=5.0,
+ side_length_2=10.0,
+ top_center=np.array([0, 0, 20], dtype=np.float64),
+ )
+ cuboid = Cuboid(configuration)
+ sc.add_shape(cuboid, ["cuboid"])
+
+        # Find the mbb
+ mbb = sc.find_mbb()
+ expected_mbb = np.array([[-2.5, -5.0, 0.0], [2.5, 5.0, 20.0]], dtype=np.float64)
+
+        # If the result is correct, the mbb is the box spanned by
+        # the opposite vertices [-2.5, -5., 0.] and [2.5, 5., 20.].
+ # The tuple must be in the correct order.
+ self.assertClose(
+ mbb[0],
+ expected_mbb[0],
+ "The minimal bounding box returned by find_mbb method is not the expected one",
+ )
+ self.assertClose(
+ mbb[1],
+ expected_mbb[1],
+ "The minimal bounding box returned by find_mbb method is not the expected one",
+ )
+
+ # The point of coordinates (1,1,1) is inside the cuboid, while (10,10,10) is not.
+ # We test check_mbox, check_inside with these two points
+ point_to_check = np.array([[1, 1, 1], [10, 10, 10]], dtype=np.float64)
+ inside_mbox = sc.inside_mbox(point_to_check)
+ inside_shape = sc.inside_shapes(point_to_check)
+
+ # Check if the points are inside the mbb
+ expected_inside_mbox = [True, False]
+ self.assertEqual(
+ inside_mbox[0],
+ expected_inside_mbox[0],
+ "The point (0,0,0) should be inside the minimal bounding box",
+ )
+ self.assertEqual(
+ inside_mbox[1],
+ expected_inside_mbox[1],
+ "The point (10,10,10) should be outside the minimal bounding box",
+ )
+
+ # Check if the points are inside the cuboid
+ expected_inside_shape = [True, False]
+ self.assertEqual(
+ inside_shape[0],
+ expected_inside_shape[0],
+ "The point (0,0,0) should be inside the cuboid",
+ )
+ self.assertEqual(
+ inside_shape[1],
+ expected_inside_shape[1],
+ "The point (10,10,10) should be outside the cuboid",
+ )
+
+ # Test generate_point_cloud method.
+        # The expected number of points is given by the volume of the cuboid divided by the voxel side cubed.
+ # The points should be inside the cuboid.
+        volume = (
+            configuration["side_length_1"]
+            * configuration["side_length_2"]
+            * np.linalg.norm(configuration["top_center"] - configuration["origin"])
+        )
+ self._check_points_inside(sc, volume, conf["voxel_size"])
+ self._check_translation(sc, expected_mbb)
+
+ # Check rotation
+ expected_mbb = np.array([[-2.5, -5.0, 0.0], [2.5, 5.0, 20.0]], dtype=np.float64)
+ cuboid.rotate(np.array([0.0, 0.0, 1.0]), np.pi / 2)
+ mbb = cuboid.find_mbb()
+ self.assertClose(mbb[0], expected_mbb[0])
+ self.assertClose(mbb[1], expected_mbb[1])
+ wireframe = np.array(sc.generate_wireframe())
+ self.assertEqual((3, 1, 4, 4), wireframe.shape)
+ self.assertTrue(np.alltrue(wireframe.reshape(3, 16).T - expected_mbb[0] >= -1e-5))
+ self.assertTrue(np.alltrue(wireframe.reshape(3, 16).T - expected_mbb[1] <= 1e-5))
+
+    # Create a cone, add it to a ShapesComposition object and test the minimal bounding box,
+    # inside_mbox, inside_shapes and generate_point_cloud methods
+    def test_cone(self):
+        # Create a ShapesComposition object; the voxel size only affects the point cloud density.
+ conf = dict(voxel_size=50, shapes=[], labels=[])
+ sc = ShapesComposition(conf)
+
+ # Add the cone to the ShapesComposition object
+ radius = 100.0
+ origin = np.array([0, 0, 100], dtype=np.float64)
+ apex = np.array([0, 0, 0], dtype=np.float64)
+ configuration = {
+ "origin": origin,
+ "radius": radius,
+ "apex": apex,
+ }
+ cone = Cone(configuration)
+ sc.add_shape(cone, ["cone"])
+
+        # Find the mbb
+ mbb = sc.find_mbb()
+ expected_mbb = np.array(
+ [[-100.0, -100.0, 0.0], [100.0, 100.0, 100.0]], dtype=np.float64
+ )
+
+        # If the result is correct, the mbb is the box spanned by
+        # the opposite vertices [-100., -100., 0.] and [100., 100., 100.].
+ # The tuple must be in the correct order.
+ self.assertClose(
+ mbb[0],
+ expected_mbb[0],
+ "The minimal bounding box returned by find_mbb method is not the expected one",
+ )
+ self.assertClose(
+ mbb[1],
+ expected_mbb[1],
+ "The minimal bounding box returned by find_mbb method is not the expected one",
+ )
+
+ # The point of coordinates (0,0,50) is inside the cone, while (0,0,-50) is not.
+ # We test check_mbox, check_inside with these two points
+ point_to_check = np.array([[0, 0, 50], [0, 0, -50]], dtype=np.float64)
+ inside_mbox = sc.inside_mbox(point_to_check)
+ inside_shape = sc.inside_shapes(point_to_check)
+
+ # Check if the points are inside the mbb
+ expected_inside_mbox = [True, False]
+ self.assertEqual(
+ inside_mbox[0],
+ expected_inside_mbox[0],
+ "The point (0,0,50) should be inside the minimal bounding box",
+ )
+ self.assertEqual(
+ inside_mbox[1],
+ expected_inside_mbox[1],
+ "The point (0,0,-50) should be outside the minimal bounding box",
+ )
+
+ # Check if the points are inside the cone
+ expected_inside_shape = [True, False]
+ self.assertEqual(
+ inside_shape[0],
+ expected_inside_shape[0],
+ "The point (0,0,50) should be inside the cone",
+ )
+ self.assertEqual(
+ inside_shape[1],
+ expected_inside_shape[1],
+ "The point (0,0,-50) should be outside the cone",
+ )
+
+ # Test generate_point_cloud method.
+        # The expected number of points is given by the volume of the cone divided by
+        # the voxel side cubed.
+ # The points should be inside the cone.
+ cone_height = np.linalg.norm(configuration["origin"] - configuration["apex"])
+ volume = (np.pi * cone_height * configuration["radius"] ** 2) / 3.0
+ self._check_points_inside(sc, volume, conf["voxel_size"])
+ self._check_translation(sc, expected_mbb)
+
+ # Check rotation
+ expected_mbb = np.array(
+ [[-100.0, -100.0, 0.0], [100.0, 100.0, 100.0]], dtype=np.float64
+ )
+ cone.rotate(np.array([0.0, 1.0, 0.0]), np.pi / 2)
+ mbb = cone.find_mbb()
+ self.assertClose(mbb[0], expected_mbb[0])
+ self.assertClose(mbb[1], expected_mbb[1])
+
+ wireframe = np.array(sc.generate_wireframe())
+ self.assertEqual((3, 1, 30, 30), wireframe.shape)
+ self.assertTrue(
+ np.alltrue(
+ np.linalg.norm(
+ (wireframe[:, 0, :, 0] - cone.origin[..., np.newaxis])[
+ np.array([0, 1])
+ ],
+ axis=0,
+ )
+ - radius
+ <= 1e-5
+ )
+ )
+ for p in wireframe[:, 0, :, 0].T:
+ self.assertTrue(
+ 0
+ <= np.absolute(p[2] - cone.origin[2])
+ <= np.absolute(cone.apex[2] - cone.origin[2])
+ )
+
+    # Create a ShapesComposition object, add a sphere and a cylinder, and test the composition
+ def test_shape_composition(self):
+ config_sphere = dict(radius=10.0, origin=np.array([0, 0, 0], dtype=np.float64))
+ config_cylinder = dict(
+ top_center=np.array([0, 0, 0], dtype=np.float64),
+ radius=25.0,
+ origin=np.array([0, 0, -40], dtype=np.float64),
+ )
+
+ with self.assertRaises(RequirementError):
+ ShapesComposition(
+ dict(
+ shapes=[
+ dict(
+ type="sphere",
+ radius=10.0,
+ center=np.array([0, 0, 0], dtype=np.float64),
+ )
+ ],
+ labels=[["label1"], ["label2"]],
+ )
+ )
+
+ with self.assertRaises(RequirementError):
+ ShapesComposition(dict(shapes=[], labels=[["label1"], ["label2"]]))
+
+ # Create a ShapesComposition object
+ conf = dict(voxel_size=10, shapes=[], labels=[])
+ sc = ShapesComposition(conf)
+
+ # Build a sphere
+ sc.add_shape(Sphere(config_sphere), ["sphere"])
+
+ # Build a cylinder
+ sc.add_shape(Cylinder(config_cylinder), ["cylinder"])
+
+ # Check if shapes filtering by labels works
+ filtered_shape = sc.filter_by_labels(["sphere"])
+ self.assertIsInstance(filtered_shape._shapes[0], Sphere)
+
+        # Test the minimal bounding box of the composition
+ # The expected mbb is [-25,-25,-40], [25,25,10]
+ mbb = sc.find_mbb()
+ expected_mbb = (
+ np.array([-25.0, -25.0, -40.0], dtype=np.float64),
+ np.array([25.0, 25.0, 10.0], dtype=np.float64),
+ )
+
+ self.assertClose(
+ mbb[0],
+ expected_mbb[0],
+ "The minimal bounding box returned by find_mbb method is not the expected one",
+ )
+ self.assertClose(
+ mbb[1],
+ expected_mbb[1],
+ "The minimal bounding box returned by find_mbb method is not the expected one",
+ )
+
+ # Test generate_point_cloud method for the composition of shapes.
+        # The expected number of points is given by the sum of the volumes divided by the voxel side cubed.
+        # The points should be inside the composition.
+ sphere_volume = 4.0 / 3.0 * np.pi * config_sphere["radius"] ** 3
+        # Cylinder height is the distance between top_center and origin.
+        cylinder_height = np.linalg.norm(
+            config_cylinder["top_center"] - config_cylinder["origin"]
+        )
+        cylinder_volume = np.pi * cylinder_height * config_cylinder["radius"] ** 2
+ total_volume = sphere_volume + cylinder_volume
+ self._check_points_inside(sc, total_volume, conf["voxel_size"])
+ self._check_translation(sc, expected_mbb)
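The helper `_check_points_inside` pins down the suite's central invariant: `generate_point_cloud` yields one point per `voxel_size**3` of shape volume, rounded down, and every point lies inside the composition. A condensed restatement for a single sphere, using the same numbers as `test_sphere`:

```python
import numpy as np

from bsb.connectivity import ShapesComposition, Sphere

sc = ShapesComposition(dict(voxel_size=25, shapes=[], labels=[]))
sc.add_shape(Sphere(dict(radius=100.0, origin=np.zeros(3))), ["sphere"])

volume = 4 / 3 * np.pi * 100.0**3  # sphere volume, ~4.19e6 um^3
expected = int(volume / 25**3)  # one point per voxel volume -> 268 points
cloud = sc.generate_point_cloud()

assert len(cloud) == expected
assert np.all(sc.inside_shapes(cloud))  # no point escapes the composition
```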
diff --git a/tests/test_morphologies.py b/tests/test_morphologies.py
index 65a1aede..ea8fd223 100644
--- a/tests/test_morphologies.py
+++ b/tests/test_morphologies.py
@@ -384,6 +384,20 @@ def test_delete_point(self):
self.assertClose(branch.radii, np.array([1, 2]))
self.assertClose([0, 1], branch.labels)
+ def test_swap_axes(self):
+ points = np.arange(9).reshape(3, 3)
+ morpho = Morphology([Branch(points, np.array([0, 1, 2]))])
+ morpho.swap_axes(0, 0) # no swap
+ self.assertAll(morpho.points == points)
+ morpho.swap_axes(0, 1)
+ self.assertAll(
+ morpho.points == np.vstack([points[:, 1], points[:, 0], points[:, 2]]).T
+ )
+ with self.assertRaises(ValueError):
+ morpho.swap_axes(3, 1)
+ with self.assertRaises(ValueError):
+ morpho.swap_axes(1, -1)
+
class TestMorphologyLabels(NumpyTestCase, unittest.TestCase):
def test_labels(self):
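The new `swap_axes(i, j)` exchanges two coordinate columns of every point in the morphology, and indices outside the 0-2 range raise a `ValueError`. A minimal usage sketch, assuming `Branch` and `Morphology` are importable from the top-level `bsb` package as elsewhere in this suite:

```python
import numpy as np

from bsb import Branch, Morphology

points = np.arange(9, dtype=float).reshape(3, 3)  # three points, columns x/y/z
morpho = Morphology([Branch(points, np.array([1.0, 1.0, 1.0]))])

morpho.swap_axes(0, 1)  # exchange the x and y columns
assert np.array_equal(morpho.points[:, 0], points[:, 1])
assert np.array_equal(morpho.points[:, 1], points[:, 0])
```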
diff --git a/tests/test_parsers.py b/tests/test_parsers.py
new file mode 100644
index 00000000..2caf0335
--- /dev/null
+++ b/tests/test_parsers.py
@@ -0,0 +1,255 @@
+import ast
+import pathlib
+import unittest
+from unittest.mock import patch
+
+from bsb.config.parsers import (
+ ConfigurationParser,
+ ParsesReferences,
+ get_configuration_parser,
+)
+from bsb.exceptions import ConfigurationWarning, FileReferenceError, PluginError
+
+
+def get_content(file: str):
+ return (pathlib.Path(__file__).parent / "data/configs" / file).read_text()
+
+
+class RefParserMock(ParsesReferences, ConfigurationParser):
+ data_description = "txt"
+ data_extensions = ("txt",)
+
+ def parse(self, content, path=None):
+ if isinstance(content, str):
+ content = ast.literal_eval(content)
+ return content, {"meta": path}
+
+ def generate(self, tree, pretty=False):
+ # Should not be called.
+ pass
+
+
+class RefParserMock2(ParsesReferences, ConfigurationParser):
+ data_description = "bla"
+ data_extensions = ("bla",)
+
+ def parse(self, content, path=None):
+ if isinstance(content, str):
+ content = content.replace("<", "{")
+ content = content.replace(">", "}")
+ content = ast.literal_eval(content)
+ return content, {"meta": path}
+
+ def generate(self, tree, pretty=False):
+ # Should not be called.
+ pass
+
+
+class TestParsersBasics(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ cls.parser = RefParserMock()
+
+ def test_get_parser(self):
+ self.assertRaises(PluginError, get_configuration_parser, "doesntexist")
+
+ def test_parse_empty_doc(self):
+ tree, meta = self.parser.parse({})
+ self.assertEqual({}, tree, "'{}' parse should produce empty dict")
+
+ def assert_basics(self, tree, meta):
+ self.assertEqual(3, tree["list"][2], "Incorrectly parsed basic Txt")
+ self.assertEqual(
+ "just like that",
+ tree["nest me hard"]["oh yea"],
+ "Incorrectly parsed nested File",
+ )
+        self.assertEqual(
+            "[1, 2, 3, 'waddup']", str(tree["list"])
+        )
+
+ def test_parse_basics(self):
+ # test from str
+ self.assert_basics(*self.parser.parse(get_content("basics.txt")))
+
+ # test from dict
+ content = {
+ "hello": "world",
+ "list": [1, 2, 3, "waddup"],
+ "nest me hard": {"oh yea": "just like that"},
+ }
+ self.assert_basics(*self.parser.parse(content))
+
+
+class TestFileRef(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ cls.parser = RefParserMock()
+
+ def test_indoc_reference(self):
+ content = ast.literal_eval(get_content("indoc_reference.txt"))
+ tree, meta = self.parser.parse(content)
+ self.assertNotIn("$ref", tree["refs"]["whats the"], "Ref key not removed")
+ self.assertEqual("key", tree["refs"]["whats the"]["secret"])
+ self.assertEqual("is hard", tree["refs"]["whats the"]["nested secrets"]["vim"])
+ self.assertEqual("convoluted", tree["refs"]["whats the"]["nested secrets"]["and"])
+ # Checking str keys order.
+ self.assertEqual(
+ str(tree["refs"]["whats the"]["nested secrets"]),
+ "",
+ )
+ self.assertEqual(tree["refs"]["whats the"], tree["refs"]["omitted_doc"])
+ content["get"]["a"] = "secret"
+ with self.assertRaises(FileReferenceError, msg="Should raise 'ref not a dict'"):
+ tree, meta = self.parser.parse(content)
+
+ @patch("bsb.config.parsers.get_configuration_parser_classes")
+ def test_far_references(self, get_content_mock):
+        # Override get_configuration_parser_classes to manually register the mock parsers
+ get_content_mock.return_value = {"txt": RefParserMock, "bla": RefParserMock2}
+ content = {
+ "refs": {
+ "whats the": {"$ref": "basics.txt#/nest me hard"},
+ "and": {"$ref": "indoc_reference.txt#/refs/whats the"},
+ "far": {"$ref": "far/targetme.bla#/this/key"},
+ },
+ "target": {"for": "another"},
+ }
+ tree, meta = self.parser.parse(
+ content,
+ path=str(
+ (pathlib.Path(__file__).parent / "data" / "configs" / "interdoc_refs.txt")
+ ),
+ )
+ self.assertIn("was", tree["refs"]["far"])
+ self.assertEqual("in another folder", tree["refs"]["far"]["was"])
+ self.assertIn("oh yea", tree["refs"]["whats the"])
+ self.assertEqual("just like that", tree["refs"]["whats the"]["oh yea"])
+
+ @patch("bsb.config.parsers.get_configuration_parser_classes")
+ def test_double_ref(self, get_content_mock):
+        # Override get_configuration_parser_classes to manually register the mock parsers
+ get_content_mock.return_value = {"txt": RefParserMock, "bla": RefParserMock2}
+ tree, meta = self.parser.parse(
+ get_content("doubleref.txt"),
+ path=str(
+ (pathlib.Path(__file__).parent / "data" / "configs" / "doubleref.txt")
+ ),
+ )
+        # Only the latest ref is included because literal_eval keeps only the latest value
+        # for duplicate keys
+ self.assertNotIn("oh yea", tree["refs"]["whats the"])
+ self.assertIn("for", tree["refs"]["whats the"])
+ self.assertIn("another", tree["refs"]["whats the"]["for"])
+
+ @patch("bsb.config.parsers.get_configuration_parser_classes")
+ def test_ref_str(self, get_content_mock):
+        # Override get_configuration_parser_classes to manually register the mock parsers
+ get_content_mock.return_value = {"txt": RefParserMock, "bla": RefParserMock2}
+ tree, meta = self.parser.parse(
+ get_content("doubleref.txt"),
+ path=str(
+ (pathlib.Path(__file__).parent / "data" / "configs" / "doubleref.txt")
+ ),
+ )
+        ref_str = str(self.parser.references[0])
+        self.assertTrue(ref_str.startswith(""))
+
+ @patch("bsb.config.parsers.get_configuration_parser_classes")
+ def test_wrong_ref(self, get_content_mock):
+        # Override get_configuration_parser_classes to manually register the mock parsers
+ get_content_mock.return_value = {"txt": RefParserMock, "bla": RefParserMock2}
+
+ content = {"refs": {"whats the": {"$ref": "basics.txt#/oooooooooooooo"}}}
+ with self.assertRaises(FileReferenceError, msg="ref should not exist"):
+ self.parser.parse(
+ content,
+ path=str(
+ (pathlib.Path(__file__).parent / "data" / "configs" / "wrong_ref.txt")
+ ),
+ )
+
+
+class TestFileImport(unittest.TestCase):
+ @classmethod
+ def setUpClass(cls):
+ cls.parser = RefParserMock()
+
+ def test_indoc_import(self):
+ tree, meta = self.parser.parse(get_content("indoc_import.txt"))
+ self.assertEqual(["with", "importable"], list(tree["imp"].keys()))
+ self.assertEqual("are", tree["imp"]["importable"]["dicts"]["that"])
+
+ def test_indoc_import_list(self):
+ from bsb.config._parse_types import parsed_list
+
+ content = ast.literal_eval(get_content("indoc_import.txt"))
+ content["arr"]["with"] = ["a", "b", ["a", "c"]]
+ tree, meta = self.parser.parse(content)
+ self.assertEqual(["with", "importable"], list(tree["imp"].keys()))
+ self.assertEqual("a", tree["imp"]["with"][0])
+        self.assertEqual(parsed_list, type(tree["imp"]["with"][2]), "not parsed_list")
+
+ def test_indoc_import_value(self):
+ content = ast.literal_eval(get_content("indoc_import.txt"))
+ content["arr"]["with"] = "a"
+ tree, meta = self.parser.parse(content)
+ self.assertEqual(["with", "importable"], list(tree["imp"].keys()))
+ self.assertEqual("a", tree["imp"]["with"])
+
+ def test_import_merge(self):
+ tree, meta = self.parser.parse(get_content("indoc_import_merge.txt"))
+ self.assertEqual(2, len(tree["imp"].keys()))
+ self.assertIn("importable", tree["imp"])
+ self.assertIn("with", tree["imp"])
+ self.assertEqual(
+ ["importable", "with"],
+ list(tree["imp"].keys()),
+ "Imported keys should follow on original keys",
+ )
+ self.assertEqual(4, tree["imp"]["importable"]["dicts"]["that"])
+ self.assertEqual("eh", tree["imp"]["importable"]["dicts"]["even"]["nested"])
+ self.assertEqual(["new", "list"], tree["imp"]["importable"]["dicts"]["with"])
+
+ def test_import_overwrite(self):
+ content = ast.literal_eval(get_content("indoc_import.txt"))
+ content["imp"]["importable"] = 10
+
+ with self.assertWarns(ConfigurationWarning) as warning:
+ tree, meta = self.parser.parse(content)
+ self.assertEqual(2, len(tree["imp"].keys()))
+ self.assertIn("importable", tree["imp"])
+ self.assertIn("with", tree["imp"])
+ self.assertEqual(
+ ["importable", "with"],
+ list(tree["imp"].keys()),
+ "Imported keys should follow on original keys",
+ )
+ self.assertEqual(10, tree["imp"]["importable"])
+
+ @patch("bsb.config.parsers.get_configuration_parser_classes")
+ def test_outdoc_import_merge(self, get_content_mock):
+        # Override get_configuration_parser_classes to manually register the mock parsers
+ get_content_mock.return_value = {"txt": RefParserMock, "bla": RefParserMock2}
+
+ file = "outdoc_import_merge.txt"
+ tree, meta = self.parser.parse(
+ get_content(file), path=pathlib.Path(__file__).parent / "data/configs" / file
+ )
+
+ expected = {
+ "with": {},
+ "importable": {
+ "dicts": {
+ "that": "are",
+ "even": {"nested": "eh"},
+ "with": ["new", "list"],
+ },
+ "diff": "added",
+ },
+ }
+ self.assertTrue(expected == tree["imp"])
diff --git a/tests/test_placement.py b/tests/test_placement.py
index 4b85e6e6..7c29c296 100644
--- a/tests/test_placement.py
+++ b/tests/test_placement.py
@@ -331,11 +331,12 @@ def test_particle_vd(self):
voxels=network.partitions.test_part.vs,
)
self.assertEqual(4, len(counts), "should have vector of counts per voxel")
- self.assertTrue(np.allclose([78, 16, 8, 27], counts, atol=1), "densities incorr")
+        # test rounded-down values
+ self.assertTrue(np.allclose([78, 15, 7, 26], counts, atol=1), "densities incorr")
network.compile(clear=True)
ps = network.get_placement_set("test_cell")
- self.assertGreater(len(ps), 90)
- self.assertLess(len(ps), 130)
+        self.assertGreater(len(ps), 125)  # rounded-down values - 1
+        self.assertLess(len(ps), 132)  # rounded-up values + 1
def _config_packing_fact(self):
return Configuration.default(
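The reworked expectations follow from per-voxel counts now being rounded down rather than stochastically rounded, which shifts both the count vector and the aggregate bounds. A worked sketch of the arithmetic behind such bounds, with hypothetical per-voxel expectations:

```python
import numpy as np

# Hypothetical per-voxel expected counts; only the rounding logic matters here.
expected = np.array([78.6, 15.8, 7.9, 26.9])
low = int(np.floor(expected).sum())  # 126: every voxel rounded down
high = int(np.ceil(expected).sum())  # 130: every voxel rounded up
print(low - 1, high + 1)  # bounds in the spirit of the updated assertions
```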
diff --git a/tests/test_util.py b/tests/test_util.py
index 7dae6430..bdd7f995 100644
--- a/tests/test_util.py
+++ b/tests/test_util.py
@@ -10,7 +10,7 @@
from scipy.spatial.transform import Rotation
from bsb import FileDependency, NeuroMorphoScheme, Scaffold
-from bsb._util import rotation_matrix_from_vectors
+from bsb._util import assert_samelen, rotation_matrix_from_vectors
class TestNetworkUtil(
@@ -93,3 +93,12 @@ def test_nm_scheme_down(self):
file.get_meta()
finally:
NeuroMorphoScheme._nm_url = url
+
+
+class TestAssertSameLength(unittest.TestCase):
+ def test_same_length(self):
+ assert_samelen([1, 2, 3], [4, 5, 6])
+ with self.assertRaises(AssertionError):
+ assert_samelen([1, 2], [2])
+ assert_samelen([[1, 2]], [3])
+ assert_samelen([], [])
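`assert_samelen` compares only top-level lengths, which is why `assert_samelen([[1, 2]], [3])` passes: both arguments have length 1. A behavior-equivalent sketch, not the actual `bsb._util` implementation:

```python
def assert_samelen(*args):
    # All arguments must report the same len(); nesting is not inspected.
    assert len({len(arg) for arg in args}) <= 1, "Arguments must have the same length."
```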