diff --git a/.github/devops/generate_public_api.py b/.github/devops/generate_public_api.py index df32c072..378bfa23 100644 --- a/.github/devops/generate_public_api.py +++ b/.github/devops/generate_public_api.py @@ -36,6 +36,10 @@ def get_public_api_map(): if isinstance(el, ast.Constant) ] for api in module_api: + if api in public_api_map: + raise RuntimeError( + f"Duplicate api key: bsb.{module}.{api} and bsb.{public_api_map[api]}.{api}" + ) public_api_map[api] = module return public_api_map diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 748ece20..705ff3cd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,4 +7,10 @@ repos: rev: 5.12.0 hooks: - id: isort - name: isort (python) \ No newline at end of file + name: isort (python) + - repo: local + hooks: + - id: api-test + name: api-test + entry: python3 .github/devops/generate_public_api.py + language: system diff --git a/CHANGELOG b/CHANGELOG index 729feaab..609fb346 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,3 +1,19 @@ +# 4.1.1 +* Fix reference of file_ref during configuration parsing when importing nodes. +* Use a more strict rule for Jobs enqueuing. +* Use certifi to fix ssl certificate issues. + +# 4.1.0 +* Added `ParsesReferences` mixin from bsb-json to allow reference and import in configuration files. + This includes also a recursive parsing of the configuration files. +* Added `swap_axes` function in morphologies +* Added API test in pre-commit config and fix duplicate entries. +* Fix `PackageRequirement`, `ConfigurationListAttribute`, and `ConfigurationDictAttribute` + inverse functions +* Refactor `CodeDependencyNode` module attribute to be either a module like string or a path string +* Fix of assert_same_len +* Fix of test_particle_vd + # 4.0.0 - Too much to list ## 40.0.0a32 diff --git a/README.md b/README.md index dcf5db0f..09503cef 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,7 @@ [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![Documentation Status](https://readthedocs.org/projects/bsb/badge/?version=latest)](https://bsb.readthedocs.io/en/latest/?badge=latest) -[![Build Status](https://travis-ci.com/dbbs-lab/bsb.svg?branch=master)](https://travis-ci.com/dbbs-lab/bsb) -[![codecov](https://codecov.io/gh/dbbs-lab/bsb/branch/master/graph/badge.svg)](https://codecov.io/gh/dbbs-lab/bsb) +[![Build Status](https://travis-ci.com/dbbs-lab/bsb-core.svg?branch=main)](https://travis-ci.com/dbbs-lab/bsb-core) +[![codecov](https://codecov.io/gh/dbbs-lab/bsb-core/branch/main/graph/badge.svg)](https://codecov.io/gh/dbbs-lab/bsb-core)

:closed_book: Read the documentation on https://bsb.readthedocs.io/en/latest

@@ -25,14 +25,14 @@ Any package in the BSB ecosystem can be installed from PyPI through `pip`. Most will want to install the main [bsb](https://pypi.org/project/bsb/) framework: ``` -pip install "bsb~=4.0" +pip install "bsb~=4.1" ``` Advanced users looking to control install an unconventional combination of plugins might be better of installing just this package, and the desired plugins: ``` -pip install "bsb-core~=4.0" +pip install "bsb-core~=4.1" ``` Note that installing `bsb-core` does not come with any plugins installed and the usually diff --git a/bsb/__init__.py b/bsb/__init__.py index 8d4f09a9..5b9bfe5f 100644 --- a/bsb/__init__.py +++ b/bsb/__init__.py @@ -7,7 +7,7 @@ install the `bsb` package instead. """ -__version__ = "4.0.1" +__version__ = "4.1.1" import functools import importlib @@ -142,7 +142,7 @@ def __dir__(): BaseCommand: typing.Type["bsb.cli.commands.BaseCommand"] BidirectionalContact: typing.Type["bsb.postprocessing.BidirectionalContact"] BootError: typing.Type["bsb.exceptions.BootError"] -BoxTree: typing.Type["bsb.voxels.BoxTree"] +BoxTree: typing.Type["bsb.trees.BoxTree"] BoxTreeInterface: typing.Type["bsb.trees.BoxTreeInterface"] Branch: typing.Type["bsb.morphologies.Branch"] BranchLocTargetting: typing.Type["bsb.simulation.targetting.BranchLocTargetting"] @@ -216,6 +216,8 @@ def __dir__(): ExternalSourceError: typing.Type["bsb.exceptions.ExternalSourceError"] FileDependency: typing.Type["bsb.storage._files.FileDependency"] FileDependencyNode: typing.Type["bsb.storage._files.FileDependencyNode"] +FileImportError: typing.Type["bsb.exceptions.FileImportError"] +FileReferenceError: typing.Type["bsb.exceptions.FileReferenceError"] FileScheme: typing.Type["bsb.storage._files.FileScheme"] FileStore: typing.Type["bsb.storage.interfaces.FileStore"] FixedIndegree: typing.Type["bsb.connectivity.general.FixedIndegree"] @@ -292,6 +294,7 @@ def __dir__(): ParameterError: typing.Type["bsb.exceptions.ParameterError"] ParameterValue: typing.Type["bsb.simulation.parameter.ParameterValue"] ParserError: typing.Type["bsb.exceptions.ParserError"] +ParsesReferences: typing.Type["bsb.config.parsers.ParsesReferences"] Partition: typing.Type["bsb.topology.partition.Partition"] PlacementError: typing.Type["bsb.exceptions.PlacementError"] PlacementIndications: typing.Type["bsb.cell_types.PlacementIndications"] diff --git a/bsb/_util.py b/bsb/_util.py index a3759661..0cb530b7 100644 --- a/bsb/_util.py +++ b/bsb/_util.py @@ -95,7 +95,7 @@ def assert_samelen(*args): """ len_ = None assert all( - (len_ := len(arg) if len_ is None else len(arg)) == len_ for arg in args + ((len_ := len(arg)) if len_ is None else len(arg)) == len_ for arg in args ), "Input arguments should be of same length." diff --git a/bsb/config/_attrs.py b/bsb/config/_attrs.py index 4c53c1fe..9aec9e56 100644 --- a/bsb/config/_attrs.py +++ b/bsb/config/_attrs.py @@ -3,6 +3,7 @@ """ import builtins +from functools import wraps import errr @@ -476,7 +477,7 @@ def __set__(self, instance, value): self.attr_name, ) from e self.flag_dirty(instance) - # The value was cast to its intented type and the new value can be set. + # The value was cast to its intended type and the new value can be set. self.fset(instance, value) root = _strict_root(instance) if _is_booted(root): @@ -687,7 +688,15 @@ def fill(self, value, _parent, _key=None): def _set_type(self, type, key=None): self.child_type = super()._set_type(type, key=False) - return self.fill + + @wraps(self.fill) + def wrapper(*args, **kwargs): + return self.fill(*args, **kwargs) + + # Expose children __inv__ function if it exists + if hasattr(self.child_type, "__inv__"): + setattr(wrapper, "__inv__", self.child_type.__inv__) + return wrapper def tree(self, instance): val = _getattr(instance, self.attr_name) @@ -838,7 +847,15 @@ def fill(self, value, _parent, _key=None): def _set_type(self, type, key=None): self.child_type = super()._set_type(type, key=False) - return self.fill + + @wraps(self.fill) + def wrapper(*args, **kwargs): + return self.fill(*args, **kwargs) + + # Expose children __inv__ function if it exists + if hasattr(self.child_type, "__inv__"): + setattr(wrapper, "__inv__", self.child_type.__inv__) + return wrapper def tree(self, instance): val = _getattr(instance, self.attr_name).items() diff --git a/bsb/config/_config.py b/bsb/config/_config.py index 7c8464cb..599b806e 100644 --- a/bsb/config/_config.py +++ b/bsb/config/_config.py @@ -7,7 +7,7 @@ from ..cell_types import CellType from ..connectivity import ConnectionStrategy from ..placement import PlacementStrategy -from ..postprocessing import AfterPlacementHook +from ..postprocessing import AfterConnectivityHook, AfterPlacementHook from ..simulation.simulation import Simulation from ..storage._files import ( CodeDependencyNode, @@ -132,8 +132,8 @@ class Configuration: """ Network connectivity strategies """ - after_connectivity: cfgdict[str, AfterPlacementHook] = config.dict( - type=AfterPlacementHook, + after_connectivity: cfgdict[str, AfterConnectivityHook] = config.dict( + type=AfterConnectivityHook, ) simulations: cfgdict[str, Simulation] = config.dict( type=Simulation, diff --git a/bsb/config/_parse_types.py b/bsb/config/_parse_types.py new file mode 100644 index 00000000..f16507c0 --- /dev/null +++ b/bsb/config/_parse_types.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +from .. import warn +from ..exceptions import ConfigurationWarning, FileImportError + + +class parsed_node: + _key = None + """Key to reference the node""" + _parent: parsed_node = None + """Parent node""" + + def location(self): + return "/" + "/".join(str(part) for part in self._location_parts([])) + + def _location_parts(self, carry): + parent = self + while (parent := parent._parent) is not None: + if parent._parent is not None: + carry.insert(0, parent._key) + carry.append(self._key or "") + return carry + + def __str__(self): + return f"" + + def __repr__(self): + return super().__str__() + + +def _traverse_wrap(node, iter): + for key, value in iter: + if type(value) in recurse_handlers: + value, iter = recurse_handlers[type(value)](value, node) + value._key = key + value._parent = node + node[key] = value + _traverse_wrap(value, iter) + + +class parsed_dict(dict, parsed_node): + def merge(self, other): + """ + Recursively merge the values of another dictionary into us + """ + for key, value in other.items(): + if key in self and isinstance(self[key], dict) and isinstance(value, dict): + if not isinstance(self[key], parsed_dict): # pragma: nocover + self[key] = d = parsed_dict(self[key]) + d._key = key + d._parent = self + self[key].merge(value) + elif isinstance(value, dict): + self[key] = d = parsed_dict(value) + d._key = key + d._parent = self + _traverse_wrap(d, d.items()) + else: + if isinstance(value, list): + value = parsed_list(value) + value._key = key + value._parent = self + self[key] = value + + def rev_merge(self, other): + """ + Recursively merge ourselves onto another dictionary + """ + m = parsed_dict(other) + _traverse_wrap(m, m.items()) + m.merge(self) + self.clear() + self.update(m) + for v in self.values(): + if hasattr(v, "_parent"): + v._parent = self + + +class parsed_list(list, parsed_node): + pass + + +def _prep_dict(node, parent): + return parsed_dict(node), node.items() + + +def _prep_list(node, parent): + return parsed_list(node), enumerate(node) + + +recurse_handlers = { + dict: _prep_dict, + parsed_dict: _prep_dict, + list: _prep_list, + parsed_list: _prep_list, +} + + +class file_ref: + def __init__(self, node, doc, ref): + self.node = node + self.doc = doc + self.ref = ref + self.key_path = node.location() + + def resolve(self, parser, target): + del self.node["$ref"] + self.node.rev_merge(target) + + def __str__(self): + return "".format(((self.doc + "#") if self.doc else "") + self.ref) + + +class file_imp(file_ref): + def __init__(self, node, doc, ref, values): + super().__init__(node, doc, ref) + self.values = values + + def resolve(self, parser, target): + del self.node["$import"] + for key in self.values: + if key not in target: + raise FileImportError( + "'{}' does not exist in import node '{}'".format(key, self.ref) + ) + if isinstance(target[key], dict): + imported = parsed_dict() + imported.merge(target[key]) + imported._key = key + imported._parent = self.node + if key in self.node: + if isinstance(self.node[key], dict): + imported.merge(self.node[key]) + else: + warn( + f"Importkey '{key}' of {self} is ignored because the parent" + f" already contains a key '{key}'" + f" with value '{self.node[key]}'.", + ConfigurationWarning, + stacklevel=3, + ) + continue + self.node[key] = imported + self._fix_references(self.node[key], parser) + elif isinstance(target[key], list): + imported, iter = _prep_list(target[key], self.node) + imported._key = key + imported._parent = self.node + self.node[key] = imported + self._fix_references(self.node[key], parser) + _traverse_wrap(imported, iter) + else: + self.node[key] = target[key] + + def _fix_references(self, node, parser): + # fix parser's references after the import. + if hasattr(parser, "references"): + for ref in parser.references: + node_loc = node.location() + if node_loc in ref.key_path: + # need to update the reference + # we descend the tree from the node until we reach the ref + # It should be here because of the merge. + loc_node = node + while node_loc != ref.key_path: + key = ref.key_path.split(node_loc, 1)[-1].split("/", 1)[-1] + if key not in loc_node: # pragma: nocover + raise ParserError( + f"Reference {ref.key_path} not found in {node_loc}. " + f"Should have been merged." + ) + loc_node = node[key] + node_loc += "/" + key + ref.node = loc_node diff --git a/bsb/config/parsers.py b/bsb/config/parsers.py index fb455e1f..5b5c6061 100644 --- a/bsb/config/parsers.py +++ b/bsb/config/parsers.py @@ -1,19 +1,195 @@ import abc import functools +import os -from ..exceptions import PluginError +from ..exceptions import FileReferenceError, PluginError +from ._parse_types import file_imp, file_ref, parsed_dict, recurse_handlers class ConfigurationParser(abc.ABC): @abc.abstractmethod - def parse(self, content, path=None): + def parse(self, content, path=None): # pragma: nocover + """ + Parse configuration file content. + + :param content: str or dict content of the file to parse + :param path: path to the file containing the configuration. + :return: configuration tree and metadata attached as dictionaries + """ pass @abc.abstractmethod - def generate(self, tree, pretty=False): + def generate(self, tree, pretty=False): # pragma: nocover + """ + Generate a string representation of the configuration tree (dictionary). + + :param dict tree: configuration tree + :param bool pretty: if True, will add indentation to the output string + :return: str representation of the configuration tree + :rtype: str + """ pass +class ParsesReferences: + """ + Mixin to decorate parse function of ConfigurationParser. + Allows for imports and references inside configuration files. + """ + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + parse = cls.parse + + def parse_with_references(self, content, path=None): + """Traverse the parsed tree and resolve any `$ref` and `$import`""" + content, meta = parse(self, content, path) + content = parsed_dict(content) + self.root = content + self.path = path or os.getcwd() + self.is_doc = path and not os.path.isdir(path) + self.references = [] + self.documents = {} + self._traverse(content, content.items()) + self.resolved_documents = {} + self._resolve_documents() + self._resolve_references() + return content, meta + + cls.parse = parse_with_references + + def _traverse(self, node, iter): + # Iterates over all values in `iter` and checks for import keys, recursion or refs + # Also wraps all nodes in their `parsed_*` counterparts. + for key, value in iter: + if self._is_import(key): + self._store_import(node) + elif type(value) in recurse_handlers: + # The recurse handlers wrap the dicts and lists and return appropriate + # iterators for them. + value, iter = recurse_handlers[type(value)](value, node) + # Set some metadata on the wrapped recursable objects. + value._key = key + value._parent = node + # Overwrite the reference to the original object with a reference to the + # wrapped object. + node[key] = value + # Recurse a level deeper + self._traverse(value, iter) + elif self._is_reference(key): + self._store_reference(node, value) + + def _is_reference(self, key): + return key == "$ref" + + def _is_import(self, key): + return key == "$import" + + def _get_ref_document(self, ref, base=None): + if "#" not in ref or ref.split("#")[0] == "": + return None + doc = ref.split("#")[0] + if not os.path.isabs(doc): + if not base: + # reference should be relative to the current configuration file + # to avoid recurrence issues. + base = os.path.dirname(self.path) + elif not os.path.isdir(base): + base = os.path.dirname(base) + if not os.path.exists(base): + raise IOError("Can't find reference directory '{}'".format(base)) + doc = os.path.abspath(os.path.join(base, doc)) + return doc + + @staticmethod + def _get_absolute_ref(node, ref): + ref = ref.split("#")[-1] + if ref.startswith("/"): + path = ref + else: + path = os.path.join(node.location(), ref) + return os.path.normpath(path).replace(os.path.sep, "/") + + def _store_reference(self, node, ref): + # Analyzes the reference and creates a ref object from the given data + doc = self._get_ref_document(ref, self.path) + ref = self._get_absolute_ref(node, ref) + if doc not in self.documents: + self.documents[doc] = set() + self.documents[doc].add(ref) + self.references.append(file_ref(node, doc, ref)) + + def _store_import(self, node): + # Analyzes the import node and creates a ref object from the given data + imp = node["$import"] + ref = imp["ref"] + doc = self._get_ref_document(ref) + ref = self._get_absolute_ref(node, ref) + if doc not in self.documents: + self.documents[doc] = set() + self.documents[doc].add(ref) + if "values" not in imp: + e = RuntimeError(f"Import node {node} is missing a 'values' list.") + e._bsbparser_show_user = True + raise e + self.references.append(file_imp(node, doc, ref, imp["values"])) + + def _resolve_documents(self): + # Iterates over the list of stored documents parses them and fetches the content + # of each reference node. + for file, refs in self.documents.items(): + if file is None: + content = self.root + else: + from . import _try_parsers + + parser_classes = get_configuration_parser_classes() + ext = file.split(".")[-1] + with open(file, "r") as f: + content = f.read() + _, content, _ = _try_parsers(content, parser_classes, ext, path=file) + try: + self.resolved_documents[file] = self._resolve_document(content, refs) + except FileReferenceError as jre: + if not file: + raise + raise FileReferenceError( + str(jre) + " in document '{}'".format(file) + ) from None + + def _resolve_document(self, content, refs): + resolved = {} + for ref in refs: + resolved[ref] = self._fetch_reference(content, ref) + return resolved + + def _fetch_reference(self, content, ref): + parts = [p for p in ref.split("/")[1:] if p] + n = content + loc = "" + for part in parts: + loc += "/" + part + try: + n = n[part] + except KeyError: + raise FileReferenceError( + "'{}' in File reference '{}' does not exist".format(loc, ref) + ) from None + if not isinstance(n, dict): + raise FileReferenceError( + "File references can only point to dictionaries. '{}' is a {}".format( + "{}' in '{}".format(loc, ref) if loc != ref else ref, + type(n).__name__, + ) + ) + return n + + def _resolve_references(self): + for ref in self.references: + target = self.resolved_documents[ref.doc][ref.ref] + ref.resolve(self, target) + + @functools.cache def get_configuration_parser_classes(): from ..plugins import discover @@ -36,6 +212,7 @@ def get_configuration_parser(parser, **kwargs): __all__ = [ "ConfigurationParser", + "ParsesReferences", "get_configuration_parser", "get_configuration_parser_classes", ] diff --git a/bsb/config/types.py b/bsb/config/types.py index aee713d4..ceea9315 100644 --- a/bsb/config/types.py +++ b/bsb/config/types.py @@ -740,6 +740,40 @@ def requirement(section): return requirement +def same_size(*list_attrs, required=True): + """ + Requirement handler for list attributes that should have the same size. + + :param list_attrs: The keys of the list attributes. + :type list_attrs: str + :param required: Whether at least one of the keys is required + :type required: bool + :returns: Requirement function + :rtype: Callable + """ + listed = ", ".join(f"`{m}`" for m in list_attrs[:-1]) + if len(list_attrs) > 1: + listed += f" {{}} `{list_attrs[-1]}`" + + def requirement(section): + common_size = -1 + count = 0 + for m in list_attrs: + if m in section: + v = builtins.list(section[m]) + if len(v) != common_size and common_size >= 0: + err_msg = f"The {listed} attributes should have the same size." + raise RequirementError(err_msg) + common_size = len(v) + count += 1 + if not count == len(list_attrs) and required: + err_msg = f"The {listed} attributes are required." + raise RequirementError(err_msg) + return False + + return requirement + + def shortform(): def requirement(section): return not section.is_shortform @@ -755,7 +789,12 @@ class ndarray(TypeHandler): :rtype: Callable """ + def __init__(self, dtype=None): + self.dtype = dtype + def __call__(self, value): + if self.dtype is not None: + return np.array(value, copy=False, dtype=self.dtype) return np.array(value, copy=False) @property @@ -780,14 +819,16 @@ class PackageRequirement(TypeHandler): def __call__(self, value): from packaging.requirements import Requirement - return Requirement(value) + requirement = Requirement(value) + requirement._cfg_inv = value + return requirement @property def __name__(self): return "package requirement" def __inv__(self, value): - return str(value) + return getattr(value, "_cfg_inv", builtins.str(value)) def __hint__(self): return "numpy==1.24.0" diff --git a/bsb/connectivity/__init__.py b/bsb/connectivity/__init__.py index 42dda4d3..19b20c2d 100644 --- a/bsb/connectivity/__init__.py +++ b/bsb/connectivity/__init__.py @@ -5,4 +5,5 @@ # isort: on from .detailed import * from .general import * +from .geometric import * from .import_ import CsvImportConnectivity diff --git a/bsb/connectivity/detailed/shared.py b/bsb/connectivity/detailed/shared.py index 928bf2a2..3754ced1 100644 --- a/bsb/connectivity/detailed/shared.py +++ b/bsb/connectivity/detailed/shared.py @@ -1,4 +1,3 @@ -from functools import cache from itertools import chain import numpy as np @@ -15,8 +14,8 @@ class Intersectional: def get_region_of_interest(self, chunk): post_ps = [ct.get_placement_set() for ct in self.postsynaptic.cell_types] - lpre, upre = self._get_rect_ext(tuple(chunk.dimensions), True) - lpost, upost = self._get_rect_ext(tuple(chunk.dimensions), False) + lpre, upre = self.presynaptic._get_rect_ext(tuple(chunk.dimensions)) + lpost, upost = self.postsynaptic._get_rect_ext(tuple(chunk.dimensions)) # Get the `np.arange`s between bounds offset by the chunk position, to be used in # `np.meshgrid` below. bounds = list( @@ -42,27 +41,6 @@ def get_region_of_interest(self, chunk): size = next(iter(self._occ_chunks)).dimensions return [t for c in clist if (t := Chunk(c, size)) in self._occ_chunks] - @cache - def _get_rect_ext(self, chunk_size, pre_post_flag): - if pre_post_flag: - types = self.presynaptic.cell_types - loader = self.presynaptic.morpho_loader - else: - types = self.postsynaptic.cell_types - loader = self.postsynaptic.morpho_loader - ps_list = [ct.get_placement_set() for ct in types] - ms_list = [loader(ps) for ps in ps_list] - if not sum(map(len, ms_list)): - # No cells placed, return smallest possible RoI. - return [np.array([0, 0, 0]), np.array([0, 0, 0])] - metas = list(chain.from_iterable(ms.iter_meta(unique=True) for ms in ms_list)) - # TODO: Combine morphology extension information with PS rotation information. - # Get the chunk coordinates of the boundaries of this chunk convoluted with the - # extension of the intersecting morphologies. - lbounds = np.min([m["ldc"] for m in metas], axis=0) // chunk_size - ubounds = np.max([m["mdc"] for m in metas], axis=0) // chunk_size - return lbounds, ubounds - def candidate_intersection(self, target_coll, candidate_coll): target_cache = [ (tset.cell_type, tset, tset.load_boxes()) for tset in target_coll.placement diff --git a/bsb/connectivity/geometric/__init__.py b/bsb/connectivity/geometric/__init__.py new file mode 100644 index 00000000..75e397bf --- /dev/null +++ b/bsb/connectivity/geometric/__init__.py @@ -0,0 +1,4 @@ +from .geometric_shapes import * +from .morphology_shape_intersection import MorphologyToShapeIntersection +from .shape_morphology_intersection import ShapeToMorphologyIntersection +from .shape_shape_intersection import ShapeHemitype, ShapeToShapeIntersection diff --git a/bsb/connectivity/geometric/geometric_shapes.py b/bsb/connectivity/geometric/geometric_shapes.py new file mode 100644 index 00000000..387596d2 --- /dev/null +++ b/bsb/connectivity/geometric/geometric_shapes.py @@ -0,0 +1,1290 @@ +from __future__ import annotations + +import abc +from typing import List, Tuple + +import numpy as np +from scipy.interpolate import interp1d +from scipy.spatial.transform import Rotation as R + +from ... import config +from ...config import types + + +def _reshape_vectors(rot_pts, x, y, z): + xrot = rot_pts[:, 0].reshape(x.shape) + yrot = rot_pts[:, 1].reshape(y.shape) + zrot = rot_pts[:, 2].reshape(z.shape) + + return xrot, yrot, zrot + + +def rotate_3d_mesh_by_vec( + x: np.ndarray, y: np.ndarray, z: np.ndarray, rot_versor: np.ndarray, angle: float +): + """ + Rotate meshgrid points according to a rotation versor and angle. + + :param numpy.ndarray[numpy.ndarray[float]] x: x coordinate points of the meshgrid + :param numpy.ndarray[numpy.ndarray[float]] y: y coordinate points of the meshgrid + :param numpy.ndarray[numpy.ndarray[float]] z: z coordinate points of the meshgrid + :param numpy.ndarray[float] rot_versor: vector representing rotation versor + :param float angle: rotation angle in radian + :return: Rotated x, y, z coordinate points + :rtype: Tuple[numpy.ndarray[numpy.ndarray[float]] + """ + + # Arrange point coordinates in shape (N, 3) for vectorized processing + pts = np.array([x.ravel(), y.ravel(), z.ravel()]).transpose() + + # Create and apply rotation + rot = R.from_rotvec(rot_versor * angle) + rot_pts = rot.apply(pts) + + # return to original shape of meshgrid + return _reshape_vectors(rot_pts, x, y, z) + + +def translate_3d_mesh_by_vec( + x: np.ndarray, y: np.ndarray, z: np.ndarray, t_vec: np.ndarray +): + """ + Translate meshgrid points according to a 3d vector. + + :param numpy.ndarray[numpy.ndarray[float]] x: x coordinate points of the meshgrid + :param numpy.ndarray[numpy.ndarray[float]] y: y coordinate points of the meshgrid + :param numpy.ndarray[numpy.ndarray[float]] z: z coordinate points of the meshgrid + :param numpy.ndarray[float] t_vec: translation vector + :return: Translated x, y, z coordinate points + :rtype: Tuple[numpy.ndarray[numpy.ndarray[float]] + """ + + # Arrange point coordinates in shape (N, 3) for vectorized processing + pts = np.array([x.ravel(), y.ravel(), z.ravel()]).transpose() + pts = pts + t_vec + # return to original shape of meshgrid + return _reshape_vectors(pts, x, y, z) + + +def rotate_3d_mesh_by_rot_mat( + x: np.ndarray, y: np.ndarray, z: np.ndarray, rot_mat: np.ndarray +): + """ + Rotate meshgrid points according to a rotation matrix. + + :param numpy.ndarray[numpy.ndarray[float]] x: x coordinate points of the meshgrid + :param numpy.ndarray[numpy.ndarray[float]] y: y coordinate points of the meshgrid + :param numpy.ndarray[numpy.ndarray[float]] z: z coordinate points of the meshgrid + :param numpy.ndarray[numpy.ndarray[float]] rot_mat: rotation matrix, shape (3,3) + :return: Rotated x, y, z coordinate points + :rtype: Tuple[numpy.ndarray[numpy.ndarray[float]] + """ + + # Arrange point coordinates in shape (N, 3) for vectorized processing + pts = np.array([x.ravel(), y.ravel(), z.ravel()]).transpose() + + # Create and apply rotation + rot = R.from_matrix(rot_mat) + rot_pts = rot.apply(pts) + + # return to original shape of meshgrid + return _reshape_vectors(rot_pts, x, y, z) + + +def _surface_resampling( + surface_function, + theta_min=0, + theta_max=2 * np.pi, + phi_min=0, + phi_max=np.pi, + precision=25, +): + # first sampling to estimate surface distribution + theta, phi = np.meshgrid( + np.linspace(theta_min, theta_max, precision), + np.linspace(phi_min, phi_max, precision), + ) + coords = surface_function(theta, phi) + + # estimate surfaces, decomposing it into parallelograms along theta and phi + delta_t_temp = np.diff(coords, axis=2) + delta_u_temp = np.diff(coords, axis=1) + delta_t = np.zeros(coords.shape) + delta_u = np.zeros(coords.shape) + delta_t[: coords.shape[0], : coords.shape[1], 1 : coords.shape[2]] = delta_t_temp + delta_u[: coords.shape[0], 1 : coords.shape[1], : coords.shape[2]] = delta_u_temp + delta_S = np.linalg.norm(np.cross(delta_t, delta_u, 0, 0), axis=2) + cum_S_t = np.cumsum(delta_S.sum(axis=0)) + cum_S_u = np.cumsum(delta_S.sum(axis=1)) + return theta, phi, cum_S_t, cum_S_u + + +def uniform_surface_sampling( + n_points, + surface_function, + theta_min=0, + theta_max=2 * np.pi, + phi_min=0, + phi_max=np.pi, + precision=25, +): + """ + Uniform-like random sampling of polar coordinates based on surface estimation. + This sampling is useful on elliptic surfaces (e.g. sphere). + Algorithm based on https://github.com/maxkapur/param_tools + + :param int n_points: number of points to sample + :param Callable[..., numpy.ndarray[float]] surface_function: function converting polar + coordinates into cartesian coordinates + :param int precision: size of grid used to estimate function surface + """ + + theta, phi, cum_S_t, cum_S_u = _surface_resampling( + surface_function, theta_min, theta_max, phi_min, phi_max, precision + ) + # resample along the cumulative surface to uniformize point distribution + # equivalent to a multinomial sampling + sampled_t = np.random.rand(n_points) * cum_S_t[-1] + sampled_u = np.random.rand(n_points) * cum_S_u[-1] + sampled_t = interp1d(cum_S_t, theta[0, :])(sampled_t) + sampled_u = interp1d(cum_S_u, phi[:, 0])(sampled_u) + + return surface_function(sampled_t, sampled_u) + + +def uniform_surface_wireframe( + n_points_1, + n_points_2, + surface_function, + theta_min=0, + theta_max=2 * np.pi, + phi_min=0, + phi_max=np.pi, + precision=25, +): + """ + Uniform-like meshgrid of size (n_point_1, n_points_2) of polar coordinates based on surface + estimation. + This meshgrid is useful on elliptic surfaces (e.g. sphere). + Algorithm based on https://github.com/maxkapur/param_tools + + :param Callable[..., numpy.ndarray[float]] surface_function: function converting polar + coordinates into cartesian coordinates + :param int precision: size of grid used to estimate function surface + """ + + theta, phi, cum_S_t, cum_S_u = _surface_resampling( + surface_function, theta_min, theta_max, phi_min, phi_max, precision + ) + sampled_t = np.linspace(0, cum_S_t[-1], n_points_1) + sampled_u = np.linspace(0, cum_S_u[-1], n_points_2) + sampled_t = interp1d(cum_S_t, theta[0, :])(sampled_t) + sampled_u = interp1d(cum_S_u, phi[:, 0])(sampled_u) + sampled_t, sampled_u = np.meshgrid(sampled_t, sampled_u) + return surface_function(sampled_t, sampled_u) + + +def _get_prod_angle_vector(hv, z_versor=np.array([0, 0, 1])): + """ + Calculate the cross product and the arc cosines angle between two vectors. + + :param numpy.ndarray hv: vector to rotate + :param numpy.ndarray z_versor: reference vector + :return: cross product and arc cosines angle + :rtype: Tuple[numpy.ndarray, numpy.ndarray] + """ + + return np.cross(z_versor, hv), np.arccos(np.dot(hv, z_versor)) + + +def _get_rotation_vector(hv, z_versor=np.array([0, 0, 1]), positive_angle=True): + """ + Calculate the rotation vector between two vectors. + + :param numpy.ndarray hv: vector to rotate + :param numpy.ndarray z_versor: reference vector + :param bool positive_angle: if False, the angle is inverted + :return: rotation vector + :rtype: scipy.spatial.transform.Rotation + """ + + perp, angle = _get_prod_angle_vector(hv, z_versor) + angle = angle if positive_angle else -angle + rot = R.from_rotvec(perp * angle) + return rot + + +def _rotate_by_coord(x, y, z, hv, origin, test_hv=False): + perp, angle = _get_prod_angle_vector(hv / np.linalg.norm(hv)) + + x, y, z = rotate_3d_mesh_by_vec(x, y, z, perp, angle) + if test_hv and hv[2] < 0: + z = -z + return translate_3d_mesh_by_vec(x, y, z, origin) + + +def _get_extrema_after_rot(extrema, origin, top_center): + height = np.linalg.norm(top_center - origin) + rot = _get_rotation_vector((top_center - origin) / height) + + for i, pt in enumerate(extrema): + extrema[i] = rot.apply(pt) + + return np.min(extrema, axis=0) + origin, np.max(extrema, axis=0) + origin + + +def inside_mbox( + points: np.ndarray[float], + mbb_min: np.ndarray[float], + mbb_max: np.ndarray[float], +): + """ + Check if the points given in input are inside the minimal bounding box. + + :param numpy.ndarray points: An array of 3D points. + :param numpy.ndarray mbb_min: 3D point representing the lowest coordinate of the + minimal bounding box. + :param numpy.ndarray mbb_max: 3D point representing the highest coordinate of the + minimal bounding box. + :return: A bool np.ndarray specifying whether each point of the input array is inside the + minimal bounding box or not. + :rtype: numpy.ndarray[bool] + """ + inside = ( + (points[:, 0] > mbb_min[0]) + & (points[:, 0] < mbb_max[0]) + & (points[:, 1] > mbb_min[1]) + & (points[:, 1] < mbb_max[1]) + & (points[:, 2] > mbb_min[2]) + & (points[:, 2] < mbb_max[2]) + ) + + return inside + + +@config.dynamic(attr_name="type", default="shape", auto_classmap=True) +class GeometricShape(abc.ABC): + """ + Base class for geometric shapes + """ + + epsilon = config.attr(type=float, required=False, default=1.0e-3) + """Tolerance value to compare coordinates.""" + + def __init__(self, **kwargs): + self.mbb_min, self.mbb_max = self.find_mbb() + + @abc.abstractmethod + def find_mbb(self): # pragma: no cover + """ + Compute the minimum bounding box surrounding the shape. + """ + pass + + def check_mbox(self, points: np.ndarray[float]): + """ + Check if the points given in input are inside the minimal bounding box. + + :param numpy.ndarray points: A cloud of points. + :return: A bool np.ndarray specifying whether each point of the input array is inside the + minimal bounding box or not. + :rtype: numpy.ndarray + """ + return inside_mbox(points, self.mbb_min, self.mbb_max) + + @abc.abstractmethod + def get_volume(self): # pragma: no cover + """ + Get the volume of the geometric shape. + :return: The volume of the geometric shape. + :rtype: float + """ + pass + + @abc.abstractmethod + def translate(self, t_vector: np.ndarray[float]): # pragma: no cover + """ + Translate the geometric shape by the vector t_vector. + + :param numpy.ndarray t_vector: The displacement vector + """ + pass + + @abc.abstractmethod + def rotate(self, r_versor: np.ndarray[float], angle: float): # pragma: no cover + """ + Rotate all the shapes around r_versor, which is a versor passing through the origin, + by the specified angle. + + :param r_versor: A versor specifying the rotation axis. + :type r_versor: numpy.ndarray[float] + :param float angle: the rotation angle, in radians. + """ + pass + + @abc.abstractmethod + def generate_point_cloud(self, npoints: int): # pragma: no cover + """ + Generate a point cloud made by npoints points. + + :param int npoints: The number of points to generate. + :return: a (npoints x 3) numpy array. + :rtype: numpy.ndarray + """ + pass + + @abc.abstractmethod + def check_inside( + self, points: np.ndarray[float] + ) -> np.ndarray[bool]: # pragma: no cover + """ + Check if the points given in input are inside the geometric shape. + + :param numpy.ndarray points: A cloud of points. + :return: A bool array with same length as points, containing whether the -ith point is + inside the geometric shape or not. + :rtype: numpy.ndarray + """ + pass + + @abc.abstractmethod + def wireframe_points(self, nb_points_1=30, nb_points_2=30): # pragma: no cover + """ + Generate a wireframe to plot the geometric shape. + If a sampling of points is needed (e.g. for sphere), the wireframe is based on a grid + of shape (nb_points_1, nb_points_2). + + :param int nb_points_1: number of points sampled along the first dimension + :param int nb_points_2: number of points sampled along the second dimension + :return: Coordinate components of the wireframe + :rtype: Tuple[numpy.ndarray[numpy.ndarray[float]] + """ + pass + + +@config.node +class ShapesComposition: + """ + A collection of geometric shapes, which can be labelled to distinguish different parts of a + neuron. + """ + + shapes = config.list( + type=GeometricShape, + required=types.same_size("shapes", "labels", required=True), + hint=[{"type": "sphere", "radius": 40.0, "center": [0.0, 0.0, 0.0]}], + ) + """List of GeometricShape that make up the neuron.""" + labels = config.list( + type=types.list(), + required=types.same_size("shapes", "labels", required=True), + hint=[["soma", "dendrites", "axon"]], + ) + """List of lists of labels associated to each geometric shape.""" + voxel_size = config.attr(type=float, required=False, default=1.0) + """Dimension of the side of a voxel, used to determine how many points must be generated + to represent the geometric shape.""" + + def __init__(self, **kwargs): + # The two corners individuating the minimal bounding box. + self._mbb_min = np.array([0.0, 0.0, 0.0]) + self._mbb_max = np.array([0.0, 0.0, 0.0]) + + self.find_mbb() + + def add_shape(self, shape: GeometricShape, labels: List[str]): + """ + Add a geometric shape to the collection + + :param GeometricShape shape: A GeometricShape to add to the collection. + :param List[str] labels: A list of labels for the geometric shape to add. + """ + # Update mbb + if len(self._shapes) == 0: + self._mbb_min = np.copy(shape.mbb_min) + self._mbb_max = np.copy(shape.mbb_max) + else: + self._mbb_min = np.minimum(self._mbb_min, shape.mbb_min) + self._mbb_max = np.maximum(self._mbb_max, shape.mbb_max) + self._shapes.append(shape) + self._labels.append(labels) + + def filter_by_labels(self, labels: List[str]) -> ShapesComposition: + """ + Filter the collection of shapes, returning only the ones corresponding the given labels. + + :param List[str] labels: A list of labels. + :return: A new ShapesComposition object containing only the shapes labelled as specified. + :rtype: ShapesComposition + """ + result = ShapesComposition(dict(voxel_size=self.voxel_size, labels=[], shapes=[])) + selected_id = np.where(np.isin(labels, self._labels))[0] + result._shapes = [self._shapes[i].__copy__() for i in selected_id] + result._labels = [self._labels[i].copy() for i in selected_id] + result.mbb_min, result.mbb_max = result.find_mbb() + return result + + def translate(self, t_vec: np.ndarray[float]): + """ + Translate all the shapes in the collection by the vector t_vec. It also automatically + translate the minimal bounding box. + + :param numpy.ndarray t_vec: The displacement vector. + """ + for shape in self._shapes: + shape.translate(t_vec) + self._mbb_min += t_vec + self._mbb_max += t_vec + + def get_volumes(self) -> List[float]: + """ + Compute the volumes of all the shapes. + + :rtype: List[float] + """ + return [shape.get_volume() for shape in self._shapes] + + def get_mbb_min(self): + """ + Returns the bottom corner of the minimum bounding box containing the collection of shapes. + + :return: The bottom corner individuating the minimal bounding box of the shapes collection. + :rtype: numpy.ndarray[float] + """ + return self._mbb_min + + def get_mbb_max(self): + """ + Returns the top corner of the minimum bounding box containing the collection of shapes. + + :return: The top corner individuating the minimal bounding box of the shapes collection. + :rtype: numpy.ndarray[float] + """ + return self._mbb_max + + def find_mbb(self) -> Tuple[np.ndarray[float], np.ndarray[float]]: + """ + Compute the minimal bounding box containing the collection of shapes. + + :return: The two corners individuating the minimal bounding box of the shapes collection. + :rtype: Tuple(numpy.ndarray[float], numpy.ndarray[float]) + """ + mins = np.empty([len(self._shapes), 3]) + maxs = np.empty([len(self._shapes), 3]) + for i, shape in enumerate(self._shapes): + mins[i, :] = shape.mbb_min + maxs[i, :] = shape.mbb_max + self._mbb_min = np.min(mins, axis=0) if len(self._shapes) > 0 else np.zeros(3) + self._mbb_max = np.max(maxs, axis=0) if len(self._shapes) > 0 else np.zeros(3) + return self._mbb_min, self._mbb_max + + def compute_n_points(self) -> List[int]: + """ + Compute the number of points to generate in a point cloud, using the dimension of the voxel + specified in self._voxel_size. + + :return: The number of points to generate. + :rtype: numpy.ndarray[int] + """ + return [int(shape.get_volume() // self.voxel_size**3) for shape in self._shapes] + + def generate_point_cloud(self) -> np.ndarray[float] | None: + """ + Generate a point cloud. The number of points to generate is determined automatically using + the voxel size. + + :return: A numpy.ndarray containing the 3D points of the cloud. If there are no shapes in + the collection, it returns None. + :rtype: numpy.ndarray[float] | None + """ + if len(self._shapes) != 0: + return np.concatenate( + [ + shape.generate_point_cloud(numpts) + for shape, numpts in zip(self._shapes, self.compute_n_points()) + ] + ) + else: + return None + + def generate_wireframe( + self, + nb_points_1=30, + nb_points_2=30, + ) -> Tuple[List, List, List] | None: + """ + Generate the wireframes of a collection of shapes. + If a sampling of points is needed for certain shapes (e.g. for sphere), their wireframe + is based on a grid of shape (nb_points_1, nb_points_2). + + :param int nb_points_1: number of points sampled along the first dimension + :param int nb_points_2: number of points sampled along the second dimension + :return: The x,y,z coordinates of the wireframe of each shape. + :rtype: Tuple[List[numpy.ndarray[numpy.ndarray[float]]]] | None + """ + if len(self._shapes) != 0: + x = [] + y = [] + z = [] + for shape in self._shapes: + # For each shape, the shape of the wireframe is different, so we need to append them + # manually + xt, yt, zt = shape.wireframe_points( + nb_points_1=nb_points_1, nb_points_2=nb_points_2 + ) + x.append(xt) + y.append(yt) + z.append(zt) + return x, y, z + return None + + def inside_mbox(self, points: np.ndarray[float]) -> np.ndarray[bool]: + """ + Check if the points given in input are inside the minimal bounding box of the collection. + + :param numpy.ndarray points: An array of 3D points. + :return: A bool np.ndarray specifying whether each point of the input array is inside the + minimal bounding box of the collection. + :rtype: numpy.ndarray[bool] + """ + return inside_mbox(points, self._mbb_min, self._mbb_max) + + def inside_shapes(self, points: np.ndarray[float]) -> np.ndarray[bool] | None: + """ + Check if the points given in input are inside at least in one of the shapes of the + collection. + + :param numpy.ndarray points: An array of 3D points. + :return: A bool numpy.ndarray specifying whether each point of the input array is inside the + collection of shapes or not. + :rtype: numpy.ndarray[bool] + """ + if len(self._shapes) != 0: + is_inside = np.full(len(points), 0, dtype=bool) + for shape in self._shapes: + tmp = shape.check_mbox(points) + if np.any(tmp): + is_inside = is_inside | shape.check_inside(points) + return is_inside + else: + return None + + +@config.node +class Ellipsoid(GeometricShape, classmap_entry="ellipsoid"): + """ + An ellipsoid, described in cartesian coordinates. + """ + + origin = config.attr( + type=types.ndarray(dtype=float), required=True, hint=[0.0, 0.0, 0.0] + ) + """The coordinates of the center of the ellipsoid.""" + lambdas = config.attr( + type=types.ndarray(dtype=float), required=True, hint=[1.0, 0.5, 2.0] + ) + """The length of the three semi-axes.""" + + @config.property(type=types.ndarray(), required=True) + def v0(self): + """The versor on which the first semi-axis lies.""" + return self._v0 + + @v0.setter + def v0(self, value): + self._v0 = np.copy(value) / np.linalg.norm(value) + + @config.property(type=types.ndarray(), required=True) + def v1(self): + """The versor on which the second semi-axis lies.""" + return self._v1 + + @v1.setter + def v1(self, value): + self._v1 = np.copy(value) / np.linalg.norm(value) + + @config.property(type=types.ndarray(), required=True) + def v2(self): + """The versor on which the third semi-axis lies.""" + return self._v2 + + @v2.setter + def v2(self, value): + self._v2 = np.copy(value) / np.linalg.norm(value) + + def find_mbb(self): + # Find the minimum bounding box, to avoid computing it every time + extrema = ( + np.array( + [ + self.lambdas[0] * self.v0, + -self.lambdas[0] * self.v0, + self.lambdas[1] * self.v1, + -self.lambdas[1] * self.v1, + self.lambdas[2] * self.v2, + -self.lambdas[2] * self.v2, + ] + ) + + self.origin + ) + mbb_min = np.min(extrema, axis=0) + mbb_max = np.max(extrema, axis=0) + return mbb_min, mbb_max + + def get_volume(self): + return np.pi * self.lambdas[0] * self.lambdas[1] * self.lambdas[2] + + def translate(self, t_vector: np.ndarray): + self.origin += t_vector + self.mbb_min += t_vector + self.mbb_max += t_vector + + def rotate(self, r_versor: np.ndarray[float], angle: float): + rot = R.from_rotvec(r_versor * angle) + self.v0 = rot.apply(self.v0) + self.v1 = rot.apply(self.v1) + self.v2 = rot.apply(self.v2) + + def surface_point(self, theta, phi): + """ + Convert polar coordinates into their 3D location on the ellipsoid surface. + + :param float|numpy.ndarray[float] theta: first polar coordinate in [0; 2*np.pi] + :param float|numpy.ndarray[float] phi: second polar coordinate in [0; np.pi] + :return: surface coordinates + :rtype: float|numpy.ndarray[float] + """ + return np.array( + [ + self.lambdas[0] * np.cos(theta) * np.sin(phi), + self.lambdas[1] * np.sin(theta) * np.sin(phi), + self.lambdas[2] * np.cos(phi), + ] + ) + + def generate_point_cloud(self, npoints: int): + sampling = uniform_surface_sampling(npoints, self.surface_point) + sampling = sampling.T * np.random.rand(npoints, 3) # sample within the shape + + # Rotate the ellipse + rmat = np.array([self.v0, self.v1, self.v2]).T + sampling = sampling.dot(rmat) + sampling = sampling + self.origin + return sampling + + def check_inside(self, points: np.ndarray[float]): + # Check if the quadratic form associated to the ellipse is less than 1 at a point + diff = points - self.origin + vmat = np.array([self.v0, self.v1, self.v2]) + diag = np.diag(1 / self.lambdas**2) + qmat = vmat.dot(diag).dot(vmat) + quad_prod = np.diagonal(diff.dot(qmat.dot(diff.T))) + + # Check if the points are inside the ellipsoid + inside_points = quad_prod < 1 + return inside_points + + def wireframe_points(self, nb_points_1=30, nb_points_2=30): + # Generate an ellipse orientated along x,y,z + x, y, z = uniform_surface_wireframe(nb_points_1, nb_points_2, self.surface_point) + # Rotate the ellipse + rmat = np.array([self.v0, self.v1, self.v2]).T + x, y, z = rotate_3d_mesh_by_rot_mat(x, y, z, rmat) + return translate_3d_mesh_by_vec(x, y, z, self.origin) + + +@config.node +class Cone(GeometricShape, classmap_entry="cone"): + """ + A cone, described in cartesian coordinates. + """ + + apex = config.attr( + type=types.ndarray(dtype=float), required=True, hint=[0.0, 1.0, 0.0] + ) + """The coordinates of the apex of the cone.""" + origin = config.attr( + type=types.ndarray(dtype=float), required=True, hint=[0.0, 0.0, 0.0] + ) + """The coordinates of the center of the cone's base.""" + radius = config.attr(type=float, required=False, default=1.0) + """The radius of the base circle.""" + + def find_mbb(self): + # Vectors identifying half of the sides of the base rectangle in xy + u = np.array([self.radius, 0, 0]) + v = np.array([0, self.radius, 0]) + + # Find the rotation angle and axis + hv = self.origin - self.apex + rot = _get_rotation_vector(hv / np.linalg.norm(hv)) + + # Rotated vectors of the box + v1 = rot.apply(u) + v2 = rot.apply(v) + v3 = self.origin - self.apex + + # Coordinates identifying the minimal bounding box + minima = np.min([v1, v2, v3, -v1, -v2], axis=0) + maxima = np.max([v1, v2, v3, -v1, -v2], axis=0) + return minima, maxima + + def get_volume(self): + h = np.linalg.norm(self.apex - self.origin) + b = np.pi * self.radius * self.radius + return b * h / 3 + + def translate(self, t_vector): + self.origin += t_vector + self.apex += t_vector + self.mbb_min += t_vector + self.mbb_max += t_vector + + def rotate(self, r_versor: np.ndarray[float], angle: float): + rot = R.from_rotvec(r_versor * angle) + self.apex = rot.apply(self.apex) + + def generate_point_cloud(self, npoints: int): + theta = np.pi * 2.0 * np.random.rand(npoints) + rand_a = np.random.rand(npoints) + rand_b = np.random.rand(npoints) + + # Height vector + hv = self.origin - self.apex + cloud = np.full((npoints, 3), 0, dtype=float) + + # Generate a cone with the apex in the origin and the origin at (0,0,1) + cloud[:, 0] = (self.radius * rand_a * np.cos(theta)) * rand_b + cloud[:, 1] = self.radius * rand_a * np.sin(theta) + cloud[:, 2] = rand_a * np.linalg.norm(hv) + + # Rotate the cone: Find the axis of rotation and compute the angle + perp, angle = _get_prod_angle_vector(hv / np.linalg.norm(hv)) + + if hv[2] < 0: + cloud[:, 2] = -cloud[:, 2] + rot = R.from_rotvec(perp * angle) + cloud = rot.apply(cloud) + + # Translate the cone + cloud = cloud + self.apex + return cloud + + def check_inside(self, points: np.ndarray[float]): + # Find the vector w of the height. + h_vector = self.origin - self.apex + height = np.linalg.norm(h_vector) + h_vector /= height + # Center the points + pts = points - self.apex + + # Rotate back to xyz + rot = _get_rotation_vector(h_vector) + rot_pts = rot.apply(pts) + + # Find the angle between the points and the apex + apex_angles = np.arccos( + np.dot((rot_pts / np.linalg.norm(rot_pts, axis=1)[..., np.newaxis]), h_vector) + ) + # Compute the cone angle + cone_angle = np.arctan(self.radius / height) + + # Select the points inside the cone + inside_points = ( + (apex_angles < cone_angle + self.epsilon) + & (rot_pts[:, 2] > np.min([self.origin[2], self.apex[2]]) - self.epsilon) + & (rot_pts[:, 2] < np.max([self.origin[2], self.apex[2]]) + self.epsilon) + ) + return inside_points + + def wireframe_points(self, nb_points_1=30, nb_points_2=30): + # Set up the grid in polar coordinates + theta = np.linspace(0, 2 * np.pi, nb_points_1) + r = np.linspace(0, self.radius, nb_points_2) + theta, r = np.meshgrid(theta, r) + + # Height vector + hv = np.array(self.origin) - np.array(self.apex) + height = np.linalg.norm(hv) + # angle = np.arctan(height/self.radius) + + # Generate a cone with the apex in the origin and the center at (0,0,1) + x = r * np.cos(theta) + y = r * np.sin(theta) + z = r * height / self.radius + + # Rotate the cone + return _rotate_by_coord(x, y, z, hv, self.apex, test_hv=True) + + +@config.node +class Cylinder(GeometricShape, classmap_entry="cylinder"): + """ + A cylinder, described in cartesian coordinates. + """ + + origin = config.attr( + type=types.ndarray(dtype=float), required=True, hint=[0.0, 0.0, 0.0] + ) + """The coordinates of the center of the bottom circle of the cylinder.""" + top_center = config.attr( + type=types.ndarray(dtype=float), required=True, hint=[0.0, 2.0, 0.0] + ) + """The coordinates of the center of the top circle of the cylinder.""" + radius = config.attr(type=float, required=False, default=1.0) + """The radius of the base circle.""" + + def find_mbb(self): + height = np.linalg.norm(self.top_center - self.origin) + # Extrema of the xyz standard cyl + extrema = [ + np.array([-self.radius, -self.radius, 0.0]), + np.array([-self.radius, self.radius, 0.0]), + np.array([self.radius, -self.radius, 0.0]), + np.array([self.radius, self.radius, 0.0]), + np.array([self.radius, self.radius, height]), + np.array([-self.radius, self.radius, height]), + np.array([self.radius, -self.radius, height]), + np.array([-self.radius, -self.radius, height]), + ] + + # Rotate the cylinder + + return _get_extrema_after_rot(extrema, self.origin, self.top_center) + + def get_volume(self): + h = np.linalg.norm(self.top_center - self.origin) + b = np.pi * self.radius * self.radius + return b * h + + def translate(self, t_vector: np.ndarray[float]): + self.origin += t_vector + self.top_center += t_vector + self.mbb_min += t_vector + self.mbb_max += t_vector + + def rotate(self, r_versor: np.ndarray[float], angle: float): + rot = R.from_rotvec(r_versor * angle) + # rotation according to bottom center + self.top_center = rot.apply(self.top_center) + + def generate_point_cloud(self, npoints: int): + # Generate an ellipse orientated along x,y,z + cloud = np.full((npoints, 3), 0, dtype=float) + theta = np.pi * 2.0 * np.random.rand(npoints) + rand = np.random.rand(npoints, 3) + height = np.linalg.norm(self.top_center - self.origin) + + # Generate an ellipsoid centered at the origin, with the semiaxes on x,y,z + cloud[:, 0] = self.radius * np.cos(theta) + cloud[:, 1] = self.radius * np.sin(theta) + cloud[:, 2] = height + cloud = cloud * rand + + # Rotate the cylinder + hv = (self.top_center - self.origin) / height + rot = _get_rotation_vector(hv / np.linalg.norm(hv)) + cloud = rot.apply(cloud) + cloud = cloud + self.origin + return cloud + + def check_inside(self, points: np.ndarray[float]): + # Translate back to origin + pts = points - self.origin + + # Rotate back to xyz + height = np.linalg.norm(self.top_center - self.origin) + rot = _get_rotation_vector( + (self.top_center - self.origin) / height, positive_angle=False + ) + rot_pts = rot.apply(pts) + # Check for intersections + inside_points = ( + (rot_pts[:, 2] < height + self.epsilon) + & (rot_pts[:, 2] > -self.epsilon) + & ( + rot_pts[:, 0] * rot_pts[:, 0] + rot_pts[:, 1] * rot_pts[:, 1] + < self.radius**2 + self.epsilon + ) + ) + return inside_points + + def wireframe_points(self, nb_points_1=30, nb_points_2=30): + # Set up the grid in polar coordinates + theta = np.linspace(0, 2 * np.pi, nb_points_1) + + # Height vector + hv = np.array(self.origin) - np.array(self.top_center) + height = np.linalg.norm(hv) + + h = np.linspace(0, height, nb_points_2) + theta, h = np.meshgrid(theta, h) + + x = self.radius * np.cos(theta) + y = self.radius * np.sin(theta) + z = h + + # Rotate the cylinder + hv = (self.top_center - self.origin) / height + return _rotate_by_coord(x, y, z, hv, self.origin) + + +@config.node +class Sphere(GeometricShape, classmap_entry="sphere"): + """ + A sphere, described in cartesian coordinates. + """ + + origin = config.attr( + type=types.ndarray(dtype=float), required=True, hint=[0.0, 0.0, 0.0] + ) + """The coordinates of the center of the sphere.""" + radius = config.attr(type=float, required=False, default=1.0) + """The radius of the sphere.""" + + def find_mbb(self): + # Find the minimum bounding box, to avoid computing it every time + mbb_min = np.array([-self.radius, -self.radius, -self.radius]) + self.origin + mbb_max = np.array([self.radius, self.radius, self.radius]) + self.origin + return mbb_min, mbb_max + + def get_volume(self): + return np.pi * 4.0 / 3.0 * np.power(self.radius, 3) + + def translate(self, t_vector: np.ndarray[float]): + self.origin += t_vector + self.mbb_min += t_vector + self.mbb_max += t_vector + + def rotate(self, r_versor: np.ndarray[float], angle: float): # pragma: no cover + # It's a sphere, it's invariant under rotation! + pass + + def surface_function(self, theta, phi): + """ + Convert polar coordinates into their 3D location on the sphere surface. + + :param float|numpy.ndarray[float] theta: first polar coordinate in [0; 2*np.pi] + :param float|numpy.ndarray[float] phi: second polar coordinate in [0; np.pi] + :return: surface coordinates + :rtype: float|numpy.ndarray[float] + """ + return np.array( + [ + self.radius * np.cos(theta) * np.sin(phi), + self.radius * np.sin(theta) * np.sin(phi), + self.radius * np.cos(phi), + ] + ) + + def generate_point_cloud(self, npoints: int): + # Generate a sphere centered at the origin. + cloud = uniform_surface_sampling(npoints, self.surface_function) + cloud = cloud.T * np.random.rand(npoints, 3) # sample within the shape + + cloud = cloud + self.origin + + return cloud + + def check_inside(self, points: np.ndarray[float]): + # Translate the points, bringing the origin to the center of the sphere, + # then check the inequality defining the sphere + pts_centered = points - self.origin + lhs = np.linalg.norm(pts_centered, axis=1) + inside_points = lhs < self.radius + self.epsilon + return inside_points + + def wireframe_points(self, nb_points_1=30, nb_points_2=30): + x, y, z = uniform_surface_wireframe( + nb_points_1, nb_points_2, self.surface_function + ) + return translate_3d_mesh_by_vec(x, y, z, self.origin) + + +@config.node +class Cuboid(GeometricShape, classmap_entry="cuboid"): + """ + A rectangular parallelepiped, described in cartesian coordinates. + """ + + origin = config.attr( + type=types.ndarray(dtype=float), required=True, hint=[0.0, 0.0, 0.0] + ) + """The coordinates of the center of the barycenter of the bottom rectangle.""" + top_center = config.attr( + type=types.ndarray(dtype=float), required=True, hint=[0.0, 1.0, 0.0] + ) + """The coordinates of the center of the barycenter of the top rectangle.""" + side_length_1 = config.attr(type=float, required=False, default=1.0) + """Length of one side of the base rectangle.""" + side_length_2 = config.attr(type=float, required=False, default=1.0) + """Length of the other side of the base rectangle.""" + + def find_mbb(self): + # Extrema of the cuboid centered at the origin + extrema = [ + np.array([-self.side_length_1 / 2.0, -self.side_length_2 / 2.0, 0.0]), + np.array([self.side_length_1 / 2.0, self.side_length_2 / 2.0, 0.0]), + np.array([-self.side_length_1 / 2.0, self.side_length_2 / 2.0, 0.0]), + np.array([self.side_length_1 / 2.0, -self.side_length_2 / 2.0, 0.0]), + np.array( + [ + self.side_length_1 / 2.0 + self.top_center[0], + self.side_length_2 / 2.0 + self.top_center[1], + self.top_center[2], + ] + ), + np.array( + [ + -self.side_length_1 / 2.0 + self.top_center[0], + self.side_length_2 / 2.0 + self.top_center[1], + self.top_center[2], + ] + ), + np.array( + [ + -self.side_length_1 / 2.0 + self.top_center[0], + -self.side_length_2 / 2.0 + self.top_center[1], + self.top_center[2], + ] + ), + np.array( + [ + self.side_length_1 / 2.0 + self.top_center[0], + -self.side_length_2 / 2.0 + self.top_center[1], + self.top_center[2], + ] + ), + ] + + # Rotate the cuboid + return _get_extrema_after_rot(extrema, self.origin, self.top_center) + + def get_volume(self): + h = np.linalg.norm(self.top_center - self.origin) + return h * self.side_length_1 * self.side_length_2 + + def translate(self, t_vector: np.ndarray[float]): + self.origin += t_vector + self.top_center += t_vector + self.mbb_min += t_vector + self.mbb_max += t_vector + + def rotate(self, r_versor: np.ndarray[float], angle: float): + rot = R.from_rotvec(r_versor * angle) + self.top_center = rot.apply(self.top_center) + + def generate_point_cloud(self, npoints: int): + # Generate a unit cuboid whose base rectangle has the barycenter in the origin + rand = np.random.rand(npoints, 3) + rand[:, 0] = rand[:, 0] - 0.5 + rand[:, 1] = rand[:, 1] - 0.5 + + # Scale the sides of the cuboid + height = np.linalg.norm(self.top_center - self.origin) + rand[:, 0] = rand[:, 0] * self.side_length_1 / 2.0 + rand[:, 1] = rand[:, 1] * self.side_length_2 / 2.0 + rand[:, 2] = rand[:, 2] * height + + # Rotate the cuboid + rot = _get_rotation_vector((self.top_center - self.origin) / height) + cloud = rot.apply(rand) + + # Translate the cuboid + cloud = cloud + self.origin + return cloud + + def check_inside(self, points: np.ndarray[float]): + # Translate back to origin + pts = points - self.origin + + # Rotate back to xyz + height = np.linalg.norm(self.top_center - self.origin) + rot = _get_rotation_vector( + (self.top_center - self.origin) / height, positive_angle=False + ) + rot_pts = rot.apply(pts) + + # Check for intersections + inside_points = ( + (rot_pts[:, 2] < height) + & (rot_pts[:, 2] > 0.0) + & (rot_pts[:, 0] < self.side_length_1) + & (rot_pts[:, 0] > -self.side_length_1) + & (rot_pts[:, 1] < self.side_length_2) + & (rot_pts[:, 1] > -self.side_length_2) + ) + return inside_points + + def wireframe_points(self, **kwargs): + a = self.side_length_1 / 2.0 + b = self.side_length_2 / 2.0 + c = np.linalg.norm(self.top_center - self.origin) + + x = np.array( + [ + [-a, a, a, -a], # x coordinate of points in bottom surface + [-a, a, a, -a], # x coordinate of points in upper surface + [-a, a, -a, a], # x coordinate of points in outside surface + [-a, a, -a, a], + ] + ) # x coordinate of points in inside surface + y = np.array( + [ + [-b, -b, b, b], # y coordinate of points in bottom surface + [-b, -b, b, b], # y coordinate of points in upper surface + [-b, -b, -b, -b], # y coordinate of points in outside surface + [b, b, b, b], + ] + ) # y coordinate of points in inside surface + z = np.array( + [ + [0.0, 0.0, 0.0, 0.0], # z coordinate of points in bottom surface + [c, c, c, c], # z coordinate of points in upper surface + [0.0, 0.0, c, c], # z coordinate of points in outside surface + [0.0, 0.0, c, c], + ] + ) # z coordinate of points in inside surface + + # Rotate the cuboid + hv = (self.top_center - self.origin) / c + return _rotate_by_coord(x, y, z, hv, self.origin) + + +@config.node +class Parallelepiped(GeometricShape, classmap_entry="parallelepiped"): + """ + A generic parallelepiped, described by the vectors (following the right-hand orientation) of the + sides in cartesian coordinates + """ + + origin = config.attr( + type=types.ndarray(dtype=float), required=True, hint=[0.0, 0.0, 0.0] + ) + """The coordinates of the left-bottom edge.""" + side_vector_1 = config.attr( + type=types.ndarray(dtype=float), required=True, hint=[1.0, 0.0, 0.0] + ) + """The first vector identifying the parallelepiped (using the right-hand orientation: the + thumb).""" + side_vector_2 = config.attr( + type=types.ndarray(dtype=float), required=True, hint=[0.0, 1.0, 0.0] + ) + """The second vector identifying the parallelepiped (using the right-hand orientation: the + index).""" + side_vector_3 = config.attr( + type=types.ndarray(dtype=float), required=True, hint=[0.0, 0.0, 1.0] + ) + """The third vector identifying the parallelepiped (using the right-hand orientation: the + middle finger).""" + + def find_mbb(self): + extrema = np.vstack( + [ + np.array([0.0, 0.0, 0.0]), + self.side_vector_1 + self.side_vector_2 + self.side_vector_3, + ] + ) + maxima = np.max(extrema, axis=0) + self.origin + minima = np.min(extrema, axis=0) + self.origin + return minima, maxima + + def get_volume(self): + vol = np.dot(self.side_vector_3, np.cross(self.side_vector_1, self.side_vector_2)) + return vol + + def translate(self, t_vector: np.ndarray[float]): + self.origin += t_vector + self.mbb_min += t_vector + self.mbb_max += t_vector + + def rotate(self, r_versor: np.ndarray[float], angle: float): + rot = R.from_rotvec(r_versor * angle) + # self.center = rot.apply(self.center) + self.side_vector_1 = rot.apply(self.side_vector_1) + self.side_vector_2 = rot.apply(self.side_vector_2) + self.side_vector_3 = rot.apply(self.side_vector_3) + + def generate_point_cloud(self, npoints: int): + # Generate a linear combination of points in the volume + cloud = np.full((npoints, 3), 0, dtype=float) + rand = np.random.rand(npoints, 3) + for i in range(npoints): + cloud[i] = ( + rand[i, 0] * self.side_vector_1 + + rand[i, 1] * self.side_vector_2 + + rand[i, 2] * self.side_vector_3 + ) + cloud += self.origin + return cloud + + def check_inside(self, points: np.ndarray[float]): + # Translate back to origin + pts = points - self.origin + + # Rotate back to xyz + height = np.linalg.norm(self.side_vector_3) + rot = _get_rotation_vector(hv=self.side_vector_3 / height, positive_angle=True) + rot_pts = rot.apply(pts) + + # Compute the Fourier components wrt to the vectors identifying the parallelepiped + + v1_norm = np.linalg.norm(self.side_vector_1) + comp1 = rot_pts.dot(self.side_vector_1) / v1_norm + v2_norm = np.linalg.norm(self.side_vector_2) + comp2 = rot_pts.dot(self.side_vector_2) / v2_norm + v3_norm = np.linalg.norm(self.side_vector_3) + comp3 = rot_pts.dot(self.side_vector_3) / v3_norm + + # The points are inside the parallelepiped if and only if all the Fourier components + # are between 0 and the norm of sides of the parallelepiped + inside_points = ( + (comp1 > 0.0) + & (comp1 < v1_norm) + & (comp2 > 0.0) + & (comp2 < v2_norm) + & (comp3 > 0.0) + & (comp3 < v3_norm) + ) + return inside_points + + def wireframe_points(self, **kwargs): + va = self.side_vector_1 + vb = self.side_vector_2 + vc = self.side_vector_3 + + a = va + b = va + vb + c = vb + d = np.array([0.0, 0.0, 0.0]) + e = va + vc + f = va + vb + vc + g = vb + vc + h = vc + + x = np.array( + [ + [a[0], b[0], c[0], d[0]], + [e[0], f[0], g[0], h[0]], + [a[0], b[0], f[0], e[0]], + [d[0], c[0], g[0], h[0]], + ] + ) + y = np.array( + [ + [a[1], b[1], c[1], d[1]], + [e[1], f[1], g[1], h[1]], + [a[1], b[1], f[1], e[1]], + [d[1], c[1], g[1], h[1]], + ] + ) + z = np.array( + [ + [a[2], b[2], c[2], d[2]], + [e[2], f[2], g[2], h[2]], + [a[2], b[2], f[2], e[2]], + [d[2], c[2], g[2], h[2]], + ] + ) + + return x + self.origin[0], y + self.origin[1], z + self.origin[2] diff --git a/bsb/connectivity/geometric/morphology_shape_intersection.py b/bsb/connectivity/geometric/morphology_shape_intersection.py new file mode 100644 index 00000000..97e93146 --- /dev/null +++ b/bsb/connectivity/geometric/morphology_shape_intersection.py @@ -0,0 +1,142 @@ +import numpy as np + +from ... import config +from ...config import types +from .. import ConnectionStrategy +from .shape_morphology_intersection import _create_geometric_conn_arrays +from .shape_shape_intersection import ShapeHemitype + + +def overlap_boxes(box1_min, box1_max, box2_min, box2_max): + """ + Check if two minimal bounding box are overlapping. + + :param numpy.ndarray box1_min: 3D point representing the lowest coordinate of the + minimal bounding box. + :param numpy.ndarray box1_max: 3D point representing the highest coordinate of the + minimal bounding box. + :param numpy.ndarray box2_min: 3D point representing the lowest coordinate of the + minimal bounding box. + :param numpy.ndarray box2_max: 3D point representing the highest coordinate of the + minimal bounding box. + """ + return np.all((box1_max >= box2_min) & (box2_max >= box1_min)) + + +@config.node +class MorphologyToShapeIntersection(ConnectionStrategy): + postsynaptic = config.attr(type=ShapeHemitype, required=True) + affinity = config.attr(type=types.fraction(), required=True, hint=0.1) + """Ratio of apositions to keep over the total number of contact points""" + pruning_ratio = config.attr(type=types.fraction(), required=True, hint=0.1) + """Ratio of conections to keep over the total number of apositions""" + + def get_region_of_interest(self, chunk): + lpre, upre = self.presynaptic._get_rect_ext(tuple(chunk.dimensions)) + post_chunks = self.postsynaptic.get_all_chunks() + tree = self.postsynaptic._get_shape_boxtree( + post_chunks, + ) + pre_mbb = [ + np.concatenate( + [ + (lpre + chunk) * chunk.dimensions, + np.maximum((upre + chunk), (lpre + chunk) + 1) * chunk.dimensions, + ] + ) + ] + return [post_chunks[i] for i in tree.query(pre_mbb, unique=True)] + + def connect(self, pre, post): + for pre_ps in pre.placement: + for post_ps in post.placement: + self._connect_type(pre_ps, post_ps) + + def _connect_type(self, pre_ps, post_ps): + pre_pos = pre_ps.load_positions() + post_pos = post_ps.load_positions() + + post_shapes = self.postsynaptic.shapes_composition.__copy__() + + to_connect_pre = np.empty([0, 3], dtype=int) + to_connect_post = np.empty([0, 3], dtype=int) + + morpho_set = pre_ps.load_morphologies() + pre_morphos = morpho_set.iter_morphologies(cache=True, hard_cache=True) + + for pre_id, (pre_coord, morpho) in enumerate(zip(pre_pos, pre_morphos)): + # Get the branches + branches = morpho.get_branches() + + # Build ids array from the morphology + pre_points_ids, pre_morpho_coord = _create_geometric_conn_arrays( + branches, pre_id, pre_coord + ) + pre_min_mbb, pre_max_mbb = morpho.bounds + pre_min_mbb += pre_coord + pre_max_mbb += pre_coord + + tmp_pre_selection = np.full( + [len(post_pos) * int(len(pre_morpho_coord) * self.affinity), 3], + -1, + dtype=int, + ) + tmp_post_selection = np.full( + [len(post_pos) * int(len(pre_morpho_coord) * self.affinity), 3], + -1, + dtype=int, + ) + ptr = 0 + + for post_id, post_coord in enumerate(post_pos): + post_shapes.translate(post_coord) + if overlap_boxes( + post_shapes.get_mbb_min(), + post_shapes.get_mbb_max(), + pre_min_mbb, + pre_max_mbb, + ): + mbb_check = post_shapes.inside_mbox(pre_morpho_coord) + if np.any(mbb_check): + inside_pts = post_shapes.inside_shapes( + pre_morpho_coord[mbb_check] + ) + if np.any(inside_pts): + local_selection = (pre_points_ids[mbb_check])[inside_pts] + if self.affinity < 1 and len(local_selection) > 0: + nb_sources = np.max( + [ + 1, + int( + np.floor(self.affinity * len(local_selection)) + ), + ] + ) + chosen_targets = np.random.choice( + local_selection.shape[0], nb_sources + ) + local_selection = local_selection[chosen_targets, :] + + selected_count = len(local_selection) + if selected_count > 0: + tmp_pre_selection[ptr : ptr + selected_count, 0] = pre_id + tmp_post_selection[ptr : ptr + selected_count, 0] = ( + post_id + ) + ptr += selected_count + + post_shapes.translate(-post_coord) + + to_connect_pre = np.vstack([to_connect_pre, tmp_pre_selection[:ptr]]) + to_connect_post = np.vstack([to_connect_post, tmp_post_selection[:ptr]]) + + if self.pruning_ratio < 1 and len(to_connect_pre) > 0: + ids_to_select = np.random.choice( + len(to_connect_pre), + int(np.floor(self.pruning_ratio * len(to_connect_pre))), + replace=False, + ) + to_connect_pre = to_connect_pre[ids_to_select] + to_connect_post = to_connect_post[ids_to_select] + + self.connect_cells(pre_ps, post_ps, to_connect_pre, to_connect_post) diff --git a/bsb/connectivity/geometric/shape_morphology_intersection.py b/bsb/connectivity/geometric/shape_morphology_intersection.py new file mode 100644 index 00000000..6fb2b383 --- /dev/null +++ b/bsb/connectivity/geometric/shape_morphology_intersection.py @@ -0,0 +1,111 @@ +import numpy as np + +from ... import config +from ...config import types +from .. import ConnectionStrategy +from .shape_shape_intersection import ShapeHemitype + + +def _create_geometric_conn_arrays(branches, ids, coord): + morpho_points = 0 + for b in branches: + morpho_points += len(b.points) + points_ids = np.empty([morpho_points, 3], dtype=int) + morpho_coord = np.empty([morpho_points, 3], dtype=float) + local_ptr = 0 + for i, b in enumerate(branches): + points_ids[local_ptr : local_ptr + len(b.points), 0] = ids + points_ids[local_ptr : local_ptr + len(b.points), 1] = i + points_ids[local_ptr : local_ptr + len(b.points), 2] = np.arange(len(b.points)) + tmp = b.points + coord + morpho_coord[local_ptr : local_ptr + len(b.points), :] = tmp + local_ptr += len(b.points) + return points_ids, morpho_coord + + +@config.node +class ShapeToMorphologyIntersection(ConnectionStrategy): + presynaptic = config.attr(type=ShapeHemitype, required=True) + affinity = config.attr(type=types.fraction(), required=True, hint=0.1) + """Ratio of apositions to keep over the total number of contact points""" + pruning_ratio = config.attr(type=types.fraction(), required=True, hint=0.1) + """Ratio of conections to keep over the total number of apositions""" + + def get_region_of_interest(self, chunk): + lpost, upost = self.postsynaptic._get_rect_ext(tuple(chunk.dimensions)) + pre_chunks = self.presynaptic.get_all_chunks() + tree = self.presynaptic._get_shape_boxtree( + pre_chunks, + ) + post_mbb = [ + np.concatenate( + [ + (lpost + chunk) * chunk.dimensions, + np.maximum((upost + chunk), (lpost + chunk) + 1) * chunk.dimensions, + ] + ) + ] + + return [pre_chunks[i] for i in tree.query(post_mbb, unique=True)] + + def connect(self, pre, post): + for pre_ps in pre.placement: + for post_ps in post.placement: + self._connect_type(pre_ps.cell_type, pre_ps, post_ps.cell_type, post_ps) + + def _connect_type(self, pre_ct, pre_ps, post_ct, post_ps): + pre_pos = pre_ps.load_positions() + post_pos = post_ps.load_positions() + + pre_shapes = self.presynaptic.shapes_composition.__copy__() + + to_connect_pre = np.empty([0, 3], dtype=int) + to_connect_post = np.empty([0, 3], dtype=int) + + morpho_set = post_ps.load_morphologies() + post_morphos = morpho_set.iter_morphologies(cache=True, hard_cache=True) + + for post_id, (post_coord, morpho) in enumerate(zip(post_pos, post_morphos)): + # Get the branches + branches = morpho.get_branches() + + # Build ids array from the morphology + post_points_ids, post_morpho_coord = _create_geometric_conn_arrays( + branches, post_id, post_coord + ) + + for pre_id, pre_coord in enumerate(pre_pos): + pre_shapes.translate(pre_coord) + mbb_check = pre_shapes.inside_mbox(post_morpho_coord) + + if np.any(mbb_check): + inside_pts = pre_shapes.inside_shapes(post_morpho_coord[mbb_check]) + # Find the morpho points inside the postsyn geometric shapes + if np.any(inside_pts): + local_selection = (post_points_ids[mbb_check])[inside_pts] + if self.affinity < 1.0 and len(local_selection) > 0: + nb_targets = np.max( + [1, int(np.floor(self.affinity * len(local_selection)))] + ) + chosen_targets = np.random.choice( + local_selection.shape[0], nb_targets + ) + local_selection = local_selection[chosen_targets, :] + selected_count = len(local_selection) + if selected_count > 0: + to_connect_post = np.vstack( + [to_connect_post, local_selection] + ) + pre_tmp = np.full([selected_count, 3], -1, dtype=int) + pre_tmp[:, 0] = pre_id + to_connect_pre = np.vstack([to_connect_pre, pre_tmp]) + pre_shapes.translate(-pre_coord) + if self.pruning_ratio < 1 and len(to_connect_pre) > 0: + ids_to_select = np.random.choice( + len(to_connect_pre), + int(np.floor(self.pruning_ratio * len(to_connect_pre))), + replace=False, + ) + to_connect_pre = to_connect_pre[ids_to_select] + to_connect_post = to_connect_post[ids_to_select] + self.connect_cells(pre_ps, post_ps, to_connect_pre, to_connect_post) diff --git a/bsb/connectivity/geometric/shape_shape_intersection.py b/bsb/connectivity/geometric/shape_shape_intersection.py new file mode 100644 index 00000000..59426284 --- /dev/null +++ b/bsb/connectivity/geometric/shape_shape_intersection.py @@ -0,0 +1,159 @@ +import numpy as np + +from ... import config +from ...config import types +from ...trees import BoxTree +from .. import ConnectionStrategy +from ..strategy import Hemitype +from .geometric_shapes import ShapesComposition + + +@config.node +class ShapeHemitype(Hemitype): + """ + Class representing a population of cells to connect with a ConnectionStrategy. + These cells' morphology is implemented as a ShapesComposition. + """ + + shapes_composition = config.attr(type=ShapesComposition, required=True) + """ + Composite shape representing the Hemitype. + """ + + def get_mbb(self, chunks, chunk_dimension): + """ + Get the list of minimal bounding box containing all cells in the `ShapeHemitype`. + + :param chunks: List of chunks containing the cell types + (see bsb.connectivity.strategy.Hemitype.get_all_chunks) + :type chunks: List[bsb.storage._chunks.Chunk] + :param chunk_dimension: Size of a chunk + :type chunk_dimension: float + :return: List of bounding boxes in the form [min_x, min_y, min_z, max_x, max_y, max_z] + for each chunk containing cells. + :rtype: List[numpy.ndarray[float, float, float, float, float, float]] + """ + return [ + np.concatenate( + [ + self.shapes_composition.get_mbb_min() + + np.array(idx_chunk) * chunk_dimension, + self.shapes_composition.get_mbb_max() + + np.array(idx_chunk) * chunk_dimension, + ] + ) + for idx_chunk in chunks + ] + + def _get_shape_boxtree(self, chunks): + mbbs = self.get_mbb(chunks, chunks[0].dimensions) + return BoxTree(mbbs) + + +@config.node +class ShapeToShapeIntersection(ConnectionStrategy): + presynaptic = config.attr(type=ShapeHemitype, required=True) + postsynaptic = config.attr(type=ShapeHemitype, required=True) + affinity = config.attr(type=types.fraction(), required=True, hint=0.1) + """Ratio of apositions to keep over the total number of contact points""" + pruning_ratio = config.attr(type=types.fraction(), required=True, hint=0.1) + """Ratio of conections to keep over the total number of apositions""" + + def get_region_of_interest(self, chunk): + # Filter postsyn chunks that overlap the presyn chunk. + post_chunks = self.postsynaptic.get_all_chunks() + tree = self.postsynaptic._get_shape_boxtree( + post_chunks, + ) + pre_mbb = self.presynaptic.get_mbb( + self.presynaptic.get_all_chunks(), chunk.dimensions + ) + return [post_chunks[i] for i in tree.query(pre_mbb, unique=True)] + + def connect(self, pre, post): + for pre_ps in pre.placement: + for post_ps in post.placement: + self._connect_type(pre_ps.cell_type, pre_ps, post_ps.cell_type, post_ps) + + def _connect_type(self, pre_ct, pre_ps, post_ct, post_ps): + pre_pos = pre_ps.load_positions() + post_pos = post_ps.load_positions() + + pre_shapes_cache = self.presynaptic.shapes_composition.__copy__() + post_shapes_cache = self.postsynaptic.shapes_composition.__copy__() + + to_connect_pre = np.empty([0, 3], dtype=int) + to_connect_post = np.empty([0, 3], dtype=int) + + for pre_id, pre_coord in enumerate(pre_pos): + # Generate pre point cloud + pre_shapes_cache.translate(pre_coord) + pre_point_cloud = pre_shapes_cache.generate_point_cloud() + + def find_mbb(coords): + maxima = np.max(coords, axis=0) + minima = np.min(coords, axis=0) + return minima, maxima + + def BoxesOverlap(box1min, box1max, box2min, box2max): + return np.all((box1max >= box2min) & (box2max >= box1min)) + + pre_mbb_min = pre_shapes_cache.get_mbb_min() + pre_mbb_max = pre_shapes_cache.get_mbb_max() + + points_per_cloud = int(len(pre_point_cloud) * self.affinity) + tmp_pre_selection = np.full( + [len(post_pos) * int(points_per_cloud), 3], -1, dtype=int + ) + tmp_post_selection = np.full( + [len(post_pos) * int(points_per_cloud), 3], -1, dtype=int + ) + ptr = 0 + for post_id, post_coord in enumerate(post_pos): + post_shapes_cache.translate(post_coord) + post_mbb_min = post_shapes_cache.get_mbb_min() + post_mbb_max = post_shapes_cache.get_mbb_max() + boxes_overlap = BoxesOverlap( + post_mbb_min, post_mbb_max, pre_mbb_min, pre_mbb_max + ) + if boxes_overlap: + # Compare pre and post mbbs + inside_mbbox = post_shapes_cache.inside_mbox(pre_point_cloud) + if np.any(inside_mbbox): + inside_pts = post_shapes_cache.inside_shapes(pre_point_cloud) + selected = pre_point_cloud[inside_pts] + + def sizemod(q, aff): + ln = len(q) + return int( + np.floor(ln * aff) + (np.random.rand() < ((ln * aff) % 1)) + ) + + selected = selected[ + np.random.randint( + len(selected), size=sizemod(selected, self.affinity) + ), + :, + ] + n_synapses = len(selected) + if n_synapses > 0: + tmp_pre_selection[ptr : ptr + n_synapses, 0] = pre_id + tmp_post_selection[ptr : ptr + n_synapses, 0] = post_id + ptr += n_synapses + post_shapes_cache.translate(-post_coord) + if ptr > 0: + to_connect_pre = np.vstack([to_connect_pre, tmp_pre_selection[:ptr]]) + to_connect_post = np.vstack([to_connect_post, tmp_post_selection[:ptr]]) + + pre_shapes_cache.translate(-pre_coord) + + if self.pruning_ratio < 1 and len(to_connect_pre) > 0: + ids_to_select = np.random.choice( + len(to_connect_pre), + int(np.floor(self.pruning_ratio * len(to_connect_pre))), + replace=False, + ) + to_connect_pre = to_connect_pre[ids_to_select] + to_connect_post = to_connect_post[ids_to_select] + + self.connect_cells(pre_ps, post_ps, to_connect_pre, to_connect_post) diff --git a/bsb/connectivity/strategy.py b/bsb/connectivity/strategy.py index 17163cff..d4de5ccf 100644 --- a/bsb/connectivity/strategy.py +++ b/bsb/connectivity/strategy.py @@ -1,7 +1,10 @@ import abc import typing +from functools import cache from itertools import chain +import numpy as np + from .. import config from .._util import ichain, obj_str_insert from ..config import refs, types @@ -45,6 +48,38 @@ class Hemitype: take too much disk space or time otherwise. """ + def get_all_chunks(self): + """ + Get the list of all chunks where the cell types were placed + + :return: List of Chunks + :rtype: List[bsb.storage._chunks.Chunk] + """ + return [ + c for ct in self.cell_types for c in ct.get_placement_set().get_all_chunks() + ] + + @cache + def _get_rect_ext(self, chunk_size): + # Returns the lower and upper boundary Chunk of the box containing the cell type population, + # based on the cell type's morphology if it exists. + # This box is centered on the Chunk [0., 0., 0.]. + # If no morphologies are associated to the cell types then the bounding box size is 0. + types = self.cell_types + loader = self.morpho_loader + ps_list = [ct.get_placement_set() for ct in types] + ms_list = [loader(ps) for ps in ps_list] + if not sum(map(len, ms_list)): + # No cells placed, return smallest possible RoI. + return [np.array([0, 0, 0]), np.array([0, 0, 0])] + metas = list(chain.from_iterable(ms.iter_meta(unique=True) for ms in ms_list)) + # TODO: Combine morphology extension information with PS rotation information. + # Get the chunk coordinates of the boundaries of this chunk convoluted with the + # extension of the intersecting morphologies. + lbounds = np.min([m["ldc"] for m in metas], axis=0) // chunk_size + ubounds = np.max([m["mdc"] for m in metas], axis=0) // chunk_size + return lbounds, ubounds + class HemitypeCollection: def __init__(self, hemitype, roi): @@ -158,6 +193,15 @@ def connect_cells(self, pre_set, post_set, src_locs, dest_locs, tag=None): cs.connect(pre_set, post_set, src_locs, dest_locs) def get_region_of_interest(self, chunk): + """ + Returns the list of chunks containing the potential postsynaptic neurons, based on a + chunk containing the presynaptic neurons. + + :param chunk: Presynaptic chunk + :type chunk: bsb.storage._chunks.Chunk + :returns: List of postsynaptic chunks + :rtype: List[bsb.storage._chunks.Chunk] + """ pass def queue(self, pool: "JobPool"): diff --git a/bsb/core.py b/bsb/core.py index ac40a4a1..4abd60bb 100644 --- a/bsb/core.py +++ b/bsb/core.py @@ -27,7 +27,7 @@ if typing.TYPE_CHECKING: from .cell_types import CellType from .config._config import NetworkNode as Network - from .postprocessing import AfterPlacementHook + from .postprocessing import AfterConnectivityHook, AfterPlacementHook from .simulation.simulation import Simulation from .storage.interfaces import ( ConnectivitySet, @@ -87,7 +87,7 @@ def _get_linked_config(storage=None): path = cfg._meta.get("path", None) if path and os.path.exists(path): with open(path, "r") as f: - cfg = bsb.config.parse_configuration_file(f) + cfg = bsb.config.parse_configuration_file(f, path=path) return cfg else: return None @@ -113,7 +113,7 @@ class Scaffold: placement: typing.Dict[str, "PlacementStrategy"] after_placement: typing.Dict[str, "AfterPlacementHook"] connectivity: typing.Dict[str, "ConnectionStrategy"] - after_connectivity: typing.Dict[str, "AfterPlacementHook"] + after_connectivity: typing.Dict[str, "AfterConnectivityHook"] simulations: typing.Dict[str, "Simulation"] def __init__(self, config=None, storage=None, clear=False, comm=None): diff --git a/bsb/exceptions.py b/bsb/exceptions.py index f06ffac8..5715a7ae 100644 --- a/bsb/exceptions.py +++ b/bsb/exceptions.py @@ -107,7 +107,10 @@ ), DataNotProvidedError=_e(), PluginError=_e("plugin"), - ParserError=_e(), + ParserError=_e( + FileImportError=_e(), + FileReferenceError=_e(), + ), ClassError=_e(), ), ) @@ -183,6 +186,8 @@ class PackageRequirementWarning(ScaffoldWarning): "EmptySelectionError", "EmptyVoxelSetError", "ExternalSourceError", + "FileImportError", + "FileReferenceError", "GatewayError", "IncompleteExternalMapError", "IncompleteMorphologyError", diff --git a/bsb/morphologies/__init__.py b/bsb/morphologies/__init__.py index 4835f56f..b5872e9d 100644 --- a/bsb/morphologies/__init__.py +++ b/bsb/morphologies/__init__.py @@ -31,12 +31,6 @@ from ..voxels import VoxelSet -def parse_morphology_file(file, **kwargs): - from .parsers import parse_morphology_file - - return parse_morphology_file(file, **kwargs) - - class MorphologySet: """ Associates a set of :class:`StoredMorphologies @@ -303,7 +297,7 @@ def branch_iter(branch): class SubTree: """ - Collection of branches, not necesarily all connected. + Collection of branches, not necessarily all connected. """ def __init__(self, branches, sanitize=True): @@ -537,8 +531,8 @@ def rotate(self, rotation, center=None): """ Point rotation - :param rot: Scipy rotation - :type: Union[scipy.spatial.transform.Rotation, List[float,float,float]] + :param rotation: Scipy rotation + :type rotation: Union[scipy.spatial.transform.Rotation, List[float,float,float]] :param center: rotation offset point. :type center: numpy.ndarray """ @@ -944,6 +938,26 @@ def as_filtered(self, labels=None): # Construct and return the morphology return self.__class__(roots, meta=self.meta.copy()) + def swap_axes(self, axis1: int, axis2: int): + """ + Interchange two axes of a morphology points. + + :param int axis1: index of the first axis to exchange + :param int axis2: index of the second axis to exchange + :return: the modified morphology + :rtype: bsb.morphologies.Morphology + """ + if not 0 <= axis1 < 3 or not 0 <= axis2 < 3: + raise ValueError( + f"Axes values should be in [0, 1, 2], {axis1}, {axis2} given." + ) + for b in self.branches: + old_column = np.copy(b.points[:, axis1]) + b.points[:, axis1] = b.points[:, axis2] + b.points[:, axis2] = old_column + + return self + def simplify(self, *args, optimize=True, **kwargs): super().simplify_branches(*args, **kwargs) if optimize: @@ -984,7 +998,7 @@ def to_graph_array(self): def _copy_api(cls, wrap=lambda self: self): - # Wraps functions so they are called with `self` wrapped in `wrap` + # Wraps functions, so they are called with `self` wrapped in `wrap` def make_wrapper(f): @functools.wraps(f) def wrapper(self, *args, **kwargs): @@ -1023,7 +1037,7 @@ def __init__(self, points, radii, labels=None, properties=None, children=None): :param radii: Array of radii associated to each point :type radii: list | numpy.ndarray :param labels: Array of labels to associate to each point - :type labels: EncodedLabels | List[str] | set | numpy.ndarray + :type labels: List[str] | set | numpy.ndarray :param properties: dictionary of per-point data to store in the branch :type properties: dict :param children: list of child branches to attach to the branch @@ -1710,5 +1724,4 @@ def _morpho_to_swc(morpho): "RotationSet", "SubTree", "branch_iter", - "parse_morphology_file", ] diff --git a/bsb/services/pool.py b/bsb/services/pool.py index 70bbc087..c48f41ac 100644 --- a/bsb/services/pool.py +++ b/bsb/services/pool.py @@ -412,7 +412,7 @@ def _dep_completed(self, dep): self._enqueue(self._pool) def _enqueue(self, pool): - if not self._deps and self._status is not JobStatus.CANCELLED: + if not self._deps and self._status is JobStatus.PENDING: # Go ahead and submit ourselves to the pool, no dependencies to wait for # The dispatcher is run on the remote worker and unpacks the data required # to execute the job contents. diff --git a/bsb/storage/_files.py b/bsb/storage/_files.py index 75c8f13a..f82bbac6 100644 --- a/bsb/storage/_files.py +++ b/bsb/storage/_files.py @@ -13,6 +13,7 @@ import urllib.parse as _up import urllib.request as _ur +import certifi as _cert import nrrd as _nrrd import requests as _rq @@ -229,7 +230,7 @@ def resolve_uri(self, file: FileDependency): def find(self, file: FileDependency): with self.create_session() as session: - response = session.head(self.resolve_uri(file)) + response = session.head(self.resolve_uri(file), verify=_cert.where()) return response.status_code == 200 def should_update(self, file: FileDependency, stored_file): @@ -251,12 +252,12 @@ def should_update(self, file: FileDependency, stored_file): def get_content(self, file: FileDependency): with self.create_session() as session: - response = session.get(self.resolve_uri(file)) + response = session.get(self.resolve_uri(file), verify=_cert.where()) return (response.content, response.encoding) def get_meta(self, file: FileDependency): with self.create_session() as session: - response = session.head(self.resolve_uri(file)) + response = session.head(self.resolve_uri(file), verify=_cert.where()) return {"headers": dict(response.headers)} def get_local_path(self, file: FileDependency): @@ -265,7 +266,9 @@ def get_local_path(self, file: FileDependency): @_cl.contextmanager def provide_stream(self, file): with self.create_session() as session: - response = session.get(self.resolve_uri(file), stream=True) + response = session.get( + self.resolve_uri(file), stream=True, verify=_cert.where() + ) response.raw.decode_content = True response.raw.auto_close = False yield (response.raw, response.encoding) @@ -294,11 +297,11 @@ def get_nm_meta(self, file: FileDependency): name = file.uri[idx : (idx + len(name))] with self.create_session() as session: try: - res = session.get(self._nm_url + self._meta + name) + res = session.get(self._nm_url + self._meta + name, verify=_cert.where()) except Exception as e: return {"archive": "none", "neuron_name": "none"} if res.status_code == 404: - res = session.get(self._nm_url) + res = session.get(self._nm_url, verify=_cert.where()) if res.status_code != 200 or "Service Interruption Notice" in res.text: warn(f"NeuroMorpho.org is down, can't retrieve morphology '{name}'.") return {"archive": "none", "neuron_name": "none"} @@ -400,24 +403,44 @@ def get_stored_file(self): @config.node class CodeDependencyNode(FileDependencyNode): + """ + Allow the loading of external code during network loading. + """ + module: str = config.attr(type=str, required=types.shortform()) + """Should be either the path to a python file or a import like string""" attr: str = config.attr(type=str) + """Attribute to extract from the loaded script""" @config.property def file(self): + import os + if getattr(self, "scaffold", None) is not None: file_store = self.scaffold.files else: file_store = None - return FileDependency( - self.module.replace(".", _os.sep) + ".py", file_store=file_store - ) + if os.path.isfile(self.module): + # Convert potential relative path to absolute path + module_file = os.path.abspath(os.path.join(os.getcwd(), self.module)) + else: + # Module like string converted to a path string relative to current folder + module_file = "./" + self.module.replace(".", _os.sep) + ".py" + return FileDependency(module_file, file_store=file_store) def __init__(self, module=None, **kwargs): super().__init__(**kwargs) if module is not None: self.module = module + def __inv__(self): + if not isinstance(self, CodeDependencyNode): + return self + res = {"module": getattr(self, "module")} + if self.attr is not None: + res["attr"] = self.attr + return res + def load_object(self): import importlib.util import sys @@ -429,7 +452,7 @@ def load_object(self): module = importlib.util.module_from_spec(spec) sys.modules[self.module] = module spec.loader.exec_module(module) - return module if self.attr is None else module[self.attr] + return module if self.attr is None else getattr(module, self.attr) finally: tmp = list(reversed(sys.path)) tmp.remove(_os.getcwd()) diff --git a/bsb/voxels.py b/bsb/voxels.py index bc46326f..0d8808b9 100644 --- a/bsb/voxels.py +++ b/bsb/voxels.py @@ -637,4 +637,4 @@ def _squash_zero(arr): return np.where(np.isclose(arr, 0), np.finfo(float).max, arr) -__all__ = ["BoxTree", "VoxelData", "VoxelSet"] +__all__ = ["VoxelData", "VoxelSet"] diff --git a/docs/bsb/bsb.connectivity.geometric.rst b/docs/bsb/bsb.connectivity.geometric.rst new file mode 100644 index 00000000..b7ba840a --- /dev/null +++ b/docs/bsb/bsb.connectivity.geometric.rst @@ -0,0 +1,37 @@ +bsb.connectivity.geometric package +================================== + +Submodules +---------- + +bsb.connectivity.geometric.geometric\_shapes module +--------------------------------------------------- + +.. automodule:: bsb.connectivity.geometric.geometric_shapes + :members: + :undoc-members: + :show-inheritance: + +bsb.connectivity.geometric.morphology\_shape\_intersection +---------------------------------------------------------- + +.. automodule:: bsb.connectivity.geometric.morphology_shape_intersection + :members: + :undoc-members: + :show-inheritance: + +bsb.connectivity.geometric.shape\_morphology\_intersection +---------------------------------------------------------- + +.. automodule:: bsb.connectivity.geometric.shape_morphology_intersection + :members: + :undoc-members: + :show-inheritance: + +bsb.connectivity.geometric.shape\_shape\_intersection +----------------------------------------------------- + +.. automodule:: bsb.connectivity.geometric.shape_shape_intersection + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/docs/bsb/bsb.connectivity.rst b/docs/bsb/bsb.connectivity.rst index e0062629..5036200e 100644 --- a/docs/bsb/bsb.connectivity.rst +++ b/docs/bsb/bsb.connectivity.rst @@ -5,9 +5,10 @@ Subpackages ----------- .. toctree:: - :maxdepth: 4 + :maxdepth: 2 bsb.connectivity.detailed + bsb.connectivity.geometric Submodules ---------- diff --git a/docs/conf.py b/docs/conf.py index 6d5456b8..69f420f8 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -30,6 +30,10 @@ # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. + +autodoc_typehints = "both" + + extensions = [ "sphinx.ext.autodoc", "sphinx.ext.todo", @@ -93,6 +97,7 @@ "getting-started/layer.rst", ] +autoclass_content = "both" # -- Options for HTML output ------------------------------------------------- diff --git a/docs/connectivity/connectivity-toc.rst b/docs/connectivity/connectivity-toc.rst index 232589e9..05742ab7 100644 --- a/docs/connectivity/connectivity-toc.rst +++ b/docs/connectivity/connectivity-toc.rst @@ -5,3 +5,4 @@ defining component connection-strategies + geometric diff --git a/docs/connectivity/geometric.rst b/docs/connectivity/geometric.rst new file mode 100644 index 00000000..dac2e0fe --- /dev/null +++ b/docs/connectivity/geometric.rst @@ -0,0 +1,221 @@ +.. _geometric: + +####################################### +Geometric shape connectivity strategies +####################################### + +To reconstruct with great details the connections between 2 neurons, one needs to provide the +morphologies of these neurons. However, this data might be lacking or incomplete. +Moreover, the reconstruction of a detailed connectivity is computationally expensive as the program +have to find all apposition of the neurons arborizations. + +To resolve these two issues, neurons' morphology is here represented by a collection of geometric +shapes representing the pre/postsynaptic neurites. The neurites apposition can be probabilistically +approximated sampling point clouds from the shapes and checking the shape bounding box +(see B1 in :ref:`Bibliography`). + +Creating simplified morphologies +******************************** + +The :class:`~bsb.connectivity.geometric.geometric_shapes.ShapesComposition` allows the simplified +representation of cell morphologies. This class leverages a list geometric shapes +(:class:`~bsb.connectivity.geometric.geometric_shapes.GeometricShape`) to represent ``sections`` +the cell morphology. Similarly to morphologies, labels should be associated to each of these +``sections``. These labels will be used as reference during connectivity. + +For each ``section`` of the simplified morphology, the class samples a set of 3D points that belong +to it. This cloud of points is used to detect connections between a source and target neuron. +The points are uniformly distributed in the ``GeometricShape``, decomposing it into 3D voxels. +The program generates as many points as the number of voxels in the volume of the shapes. + +Geometric shapes +---------------- + +Pre-defined GeometricShape implemented can be found in the ``~bsb.connectivity.geometric`` package. +Each shape has its own set of parameters. We provide here an example of the configuration +for a sphere: + +.. autoconfig:: bsb.connectivity.geometric.Sphere + :no-imports: + :max-depth: 1 + +If needed, a user can define its own geometric shape, creating a new class inheriting from the base +virtual class :class:`~bsb.connectivity.geometric.geometric_shapes.GeometricShape`. + +ShapesComposition +----------------- +To instantiate a :class:`~bsb.connectivity.geometric.geometric_shapes.ShapesComposition`, you need +to provide a list of ``shapes`` together with their ``labels``: a list of lists of strings. +``shapes`` and ``labels`` should have the same size. For each shape, multiple labels can be provided. +You can additionally control the number of points sampled for connectivity with the parameter +``voxel_size``. This parameter corresponds to the side length of one voxel used to decompose the +shape collection. + +.. autoconfig:: bsb.connectivity.geometric.ShapesComposition + :no-imports: + :max-depth: 2 + +Here, we represent the cell as a single sphere for the soma, a cone for the dendrites and a cylinder +for the axon: + +.. code-block:: json + + "my_neuron": + { + "voxel_size": 25, + "shapes": + [ + { + "type": "sphere", + "radius": 40.0, + "center": [0., 0., 0.]}, + { + "type": "cone", + "center": [0., 0., 0.], + "radius": 100.0, + "apex": [0., 100., 0.]}, + { + "type": "cylinder", + "radius": 100.0, + "top_center": [0., 0., 0.], + "bottom_center": [0., 0., 10.] + } + ], + "labels": + [ + ["soma"], + ["basal_dendrites", "apical_dendrites"], + ["axon"] + ], + } + +Geometric shape connectivity +**************************** + +The configuration of the geometric shape strategies are similar to the other connectivity strategies +(see :class:`~bsb.connectivity.detailed.voxel_intersection.VoxelIntersection`). + +The ``ShapesComposition`` configuration should be provided with the field ``shape_compositions`` in +the pre- and/or postsynaptic field (dependant on the strategy chosen). + +The parameters ``morphology_labels`` here specifies which shapes of the ``shape_compositions`` in +:class:`~bsb.connectivity.geometric.geometric_shapes.ShapesComposition` must be used +(corresponds to values stored in ``labels``). + +The ``affinity`` parameter controls the probability to form a connection. +Three different connectivity strategies based on ``ShapesComposition`` are available. + +MorphologyToShapeIntersection +----------------------------- + +The class :class:`~bsb.connectivity.geometric.morphology_shape_intersection.MorphologyToShapeIntersection` +creates connections between the points of the morphology of the presynaptic cell and a geometric shape composition +representing a postsynaptic cell, checking if the points of the morphology are inside the geometric +shapes representing the postsynaptic cells. +This connection strategy is suitable when we have a detailed morphology of the presynaptic cell, but +not of the postsynaptic cell. + +Configuration example: + +.. code-block:: json + + "stellate_to_purkinje": + { + "strategy": "bsb.connectivity.MorphologyToShapeIntersection", + "presynaptic": { + "cell_types": ["stellate_cell"], + "morphology_labels": ["axon"], + }, + "postsynaptic": { + "cell_types": ["purkinje_cell"], + "morphology_labels": ["sc_targets"], + "shape_compositions" : [{ + "voxel_size": 25, + "shapes": [{"type": "sphere", "radius": 40.0, "center": [0., 0., 0.]}], + "labels": [["soma", "dendrites", "sc_targets", "axon"]], + }] + }, + "affinity": 0.1, + "pruning_ratio": 0.5 + } + +ShapeToMorphologyIntersection +----------------------------- + +The class :class:`~bsb.connectivity.geometric.shape_morphology_intersection.ShapeToMorphologyIntersection` +creates connections between the point cloud representing the presynaptic cell the points of the +morphology of a postsynaptic cell, checking if the points of the morphology are inside the +geometric shapes representing the presynaptic cells. +This connection strategy is suitable when we have a detailed morphology of the postsynaptic cell, +but not of the presynaptic cell. + +Configuration example: + +.. code-block:: json + + "stellate_to_purkinje": + { + "strategy": "bsb.connectivity.ShapeToMorphologyIntersection", + "presynaptic": { + "cell_types": ["stellate_cell"], + "morphology_labels": ["axon"], + "shape_compositions" : [{ + "voxel_size": 25, + "shapes": [{"type": "sphere", "radius": 40.0, "center": [0., 0., 0.]}], + "labels": [["soma", "dendrites", "axon"]], + }] + }, + "postsynaptic": { + "cell_types": ["purkinje_cell"], + "morphology_labels": ["sc_targets"] + }, + "affinity": 0.1, + "pruning_ratio": 0.5 + } + +ShapeToShapeIntersection +------------------------ + +The class :class:`~bsb.connectivity.geometric.shape_shape_intersection.ShapeToShapeIntersection` +creates connections between the geometric shape compositions representing the presynaptic and postsynaptic cells. +This strategy forms a connections generating a number of points inside the presynaptic probability +point cloud and checking if they are inside the geometric shapes representing the postsynaptic cell. +One point per voxel is generated. +This connection strategy is suitable when we do not have a detailed morphology of neither the +presynaptic nor the postsynaptic cell. + +Configuration example: + +.. code-block:: json + + "stellate_to_purkinje": + { + "strategy": "bsb.connectivity.ShapeToShapeIntersection", + "presynaptic": { + "cell_types": ["stellate_cell"], + "morphology_labels": ["axon"], + "shape_compositions" : [{ + "voxel_size": 25, + "shapes": [{"type": "sphere", "radius": 40.0, "center": [0., 0., 0.]}], + "labels": [["soma", "dendrites", "axon"]], + }] + }, + "postsynaptic": { + "cell_types": ["purkinje_cell"], + "morphology_labels": ["sc_targets"], + "shape_compositions" : [{ + "voxel_size": 25, + "shapes": [{"type": "sphere", "radius": 40.0, "center": [0., 0., 0.]}], + "labels": [["soma", "dendrites", "sc_targets", "axon"]], + }] + }, + "affinity": 0.1, + "pruning_ratio": 0.7, + } + +.. _Bibliography: + +Bibliography +************ + +* B1: Gandolfi D, Mapelli J, Solinas S, De Schepper R, Geminiani A, Casellato C, D'Angelo E, Migliore M. A realistic morpho-anatomical connection strategy for modelling full-scale point-neuron microcircuits. Sci Rep. 2022 Aug 16;12(1):13864. doi: 10.1038/s41598-022-18024-y. Erratum in: Sci Rep. 2022 Nov 17;12(1):19792. PMID: 35974119; PMCID: PMC9381785. \ No newline at end of file diff --git a/docs/getting-started/installation.rst b/docs/getting-started/installation.rst index a114da3e..8bb38f42 100644 --- a/docs/getting-started/installation.rst +++ b/docs/getting-started/installation.rst @@ -11,7 +11,7 @@ The BSB framework can be installed using ``pip``: .. code-block:: bash - pip install "bsb~=4.0" + pip install "bsb~=4.1" You can verify that the installation works with: @@ -47,7 +47,7 @@ To then install the BSB with parallel MPI support: .. code-block:: bash - pip install "bsb[parallel]~=4.0" + pip install "bsb[parallel]~=4.1" Simulator backends ================== diff --git a/pyproject.toml b/pyproject.toml index ba9620e4..de0666b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,7 +100,7 @@ line-length = 90 profile = "black" [tool.bumpversion] -current_version = "4.0.1" +current_version = "4.1.1" parse = "(?P\\d+)\\.(?P\\d+)\\.(?P\\d+)" serialize = ["{major}.{minor}.{patch}"] search = "{current_version}" diff --git a/tests/data/code_dependency.py b/tests/data/code_dependency.py new file mode 100644 index 00000000..1feec292 --- /dev/null +++ b/tests/data/code_dependency.py @@ -0,0 +1 @@ +from bsb_test.configs.double_neuron import tree diff --git a/tests/data/configs/basics.txt b/tests/data/configs/basics.txt new file mode 100644 index 00000000..bf60cc48 --- /dev/null +++ b/tests/data/configs/basics.txt @@ -0,0 +1,5 @@ +{ + "hello": "world", + "list": [1, 2, 3, "waddup"], + "nest me hard": {"oh yea": "just like that"}, +} \ No newline at end of file diff --git a/tests/data/configs/doubleref.txt b/tests/data/configs/doubleref.txt new file mode 100644 index 00000000..8bf67684 --- /dev/null +++ b/tests/data/configs/doubleref.txt @@ -0,0 +1,8 @@ +{ + "refs": { + "whats the": { + "$ref": "basics.txt#/nest me hard", + "$ref": "indoc_reference.txt#/target" + } + } +} \ No newline at end of file diff --git a/tests/data/configs/far/targetme.bla b/tests/data/configs/far/targetme.bla new file mode 100644 index 00000000..36fc2511 --- /dev/null +++ b/tests/data/configs/far/targetme.bla @@ -0,0 +1,25 @@ +< + "this": < + "key": < + "was": "in another folder", + "can": < + "i": < + "be": "imported" + > + >, + "with": < + "this": "string" + >, + "in": < + "my": "dict", + "or": < + "will": "it", + "give": "an", + "error": < + "on": "import" + > + > + > + > + > +> diff --git a/tests/data/configs/indoc_import.txt b/tests/data/configs/indoc_import.txt new file mode 100644 index 00000000..647ce395 --- /dev/null +++ b/tests/data/configs/indoc_import.txt @@ -0,0 +1,20 @@ +{ + "arr": { + "with": {}, + "many": {}, + "importable": { + "dicts": { + "that": "are", + "even": { + "nested": "eh" + } + } + } + }, + "imp": { + "$import": { + "ref": "#/arr", + "values": ["with", "importable"] + } + } +} \ No newline at end of file diff --git a/tests/data/configs/indoc_import_merge.txt b/tests/data/configs/indoc_import_merge.txt new file mode 100644 index 00000000..65395735 --- /dev/null +++ b/tests/data/configs/indoc_import_merge.txt @@ -0,0 +1,28 @@ +{ + "arr": { + "with": {}, + "many": {}, + "importable": { + "dicts": { + "that": "are", + "even": { + "nested": "eh" + }, + "with": ["l", "i", "s", "t", "s"] + } + } + }, + "imp": { + "$import": { + "ref": "#/arr", + "values": ["with", "importable"] + }, + "importable": { + "diff": "added", + "dicts": { + "that": 4, + "with": ["new", "list"] + } + } + } +} \ No newline at end of file diff --git a/tests/data/configs/indoc_reference.txt b/tests/data/configs/indoc_reference.txt new file mode 100644 index 00000000..634f15be --- /dev/null +++ b/tests/data/configs/indoc_reference.txt @@ -0,0 +1,22 @@ +{ + "get": { + "a": { + "secret": "key", + "nested secrets": { + "vim": "is hard", + "and": "convoluted" + } + } + }, + "refs": { + "whats the": { + "$ref": "#/get/a" + }, + "omitted_doc": { + "$ref": "/get/a" + } + }, + "target": { + "for": "another" + } +} \ No newline at end of file diff --git a/tests/data/configs/outdoc_import_merge.txt b/tests/data/configs/outdoc_import_merge.txt new file mode 100644 index 00000000..6ea5abb7 --- /dev/null +++ b/tests/data/configs/outdoc_import_merge.txt @@ -0,0 +1,14 @@ +{ + "imp": { + "importable":{ + "$import": { + "ref": "indoc_import.txt#/imp/importable", + "values": ["dicts"] + }, + } + }, + "$import": { + "ref": "indoc_import_merge.txt#/", + "values": ["imp"] + } +} diff --git a/tests/test_configuration.py b/tests/test_configuration.py index 6c1c2374..d890678a 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -1,5 +1,6 @@ import inspect import json +import os.path import sys import unittest @@ -10,12 +11,14 @@ get_test_config, list_test_configs, ) +from bsb_test.configs import get_test_config_module import bsb from bsb import ( CastError, CfgReferenceError, ClassMapMissingError, + CodeDependencyNode, ConfigurationError, ConfigurationWarning, DynamicClassInheritanceError, @@ -28,6 +31,7 @@ UnfitClassCastError, UnresolvedClassCastError, config, + from_storage, ) from bsb._package_spec import get_missing_requirement_reason from bsb.config import Configuration, _attrs, compose_nodes, types @@ -1253,6 +1257,37 @@ class TestClass: with self.assertRaises(RequirementError): TestClass(a="1", b="6", c="3") + def test_code_dependency_node(self): + @config.node + class Test: + c = config.attr(type=CodeDependencyNode) + + module = get_test_config_module("double_neuron") + script = str(module.__file__) + # Test with a module like string + import_str = os.path.relpath( + os.path.join(os.path.dirname(__file__), "data/code_dependency") + ).replace(os.sep, ".") + b = Test( + c=import_str, + _parent=TestRoot(), + ) + self.assertEqual(b.c.load_object().tree, module.tree) + # test with a file + b = Test( + c={"module": script}, + _parent=TestRoot(), + ) + # Test variable tree inside the file. + self.assertEqual(b.c.load_object().tree, module.tree) + self.assertEqual(b.__tree__(), {"c": {"module": script}}) + # Test with relative path + b = Test( + c={"module": os.path.relpath(script), "attr": "tree"}, + _parent=TestRoot(), + ) + self.assertEqual(b.c.load_object(), module.tree) + @config.dynamic( type=types.in_classmap(), @@ -1649,7 +1684,7 @@ def test_composite_node(self): assert type(self.tested.attrC == config.ConfigurationAttribute) -class TestPackageRequirements(unittest.TestCase): +class TestPackageRequirements(RandomStorageFixture, unittest.TestCase, engine_name="fs"): def test_basic_version(self): self.assertIsNone(get_missing_requirement_reason("bsb-core==" + bsb.__version__)) @@ -1669,3 +1704,15 @@ def test_uninstalled_package(self): self.assertIsNotNone(get_missing_requirement_reason("bsb-core-soup==4.0")) with self.assertWarns(PackageRequirementWarning): Configuration.default(packages=["bsb-core-soup==4.0"]) + + def test_installed_package(self): + self.assertIsNone(get_missing_requirement_reason(f"bsb-core~={bsb.__version__}")) + # Should produce no warnings + cfg = Configuration.default(packages=[f"bsb-core~={bsb.__version__}"]) + # Checking that the config with package requirements can be saved in storage + self.network = Scaffold(cfg, self.storage) + # Checking if the config + network2 = from_storage(self.storage.root) + self.assertEqual( + self.network.configuration.packages, network2.configuration.packages + ) diff --git a/tests/test_connectivity.py b/tests/test_connectivity.py index c864e3a9..0705dbb7 100644 --- a/tests/test_connectivity.py +++ b/tests/test_connectivity.py @@ -457,7 +457,6 @@ def connect_spy(strat, pre, post): ) -@unittest.skip("https://github.com/dbbs-lab/bsb-core/issues/820") class TestVoxelIntersection( RandomStorageFixture, NetworkFixture, @@ -680,7 +679,7 @@ def test_multi_indegree(self): self.network.compile() for post_name in ("inhibitory", "extra"): post_ps = self.network.get_placement_set(post_name) - total = np.zeros(len(post_ps)) + total = np.zeros(len(post_ps), dtype=int) for pre_name in ("excitatory", "extra"): cs = self.network.get_connectivity_set( f"multi_indegree_{pre_name}_to_{post_name}" @@ -688,7 +687,7 @@ def test_multi_indegree(self): _, post_locs = cs.load_connections().all() ps = self.network.get_placement_set("inhibitory") u, c = np.unique(post_locs[:, 0], return_counts=True) - this = np.zeros(len(post_ps)) + this = np.zeros(len(post_ps), dtype=int) this[u] = c total += this self.assertTrue(np.all(total == 50), "Not all cells have indegree 50") diff --git a/tests/test_geometric_connectivity.py b/tests/test_geometric_connectivity.py new file mode 100644 index 00000000..0e2d7004 --- /dev/null +++ b/tests/test_geometric_connectivity.py @@ -0,0 +1,214 @@ +import unittest + +from bsb_test import ( + FixedPosConfigFixture, + MorphologiesFixture, + NetworkFixture, + NumpyTestCase, + RandomStorageFixture, +) + +from bsb import Configuration, DatasetNotFoundError, Scaffold +from bsb.connectivity import ( + MorphologyToShapeIntersection, + ShapeToMorphologyIntersection, + ShapeToShapeIntersection, +) + + +class TestShapeConnectivity( + RandomStorageFixture, + FixedPosConfigFixture, + NetworkFixture, + MorphologiesFixture, + NumpyTestCase, + unittest.TestCase, + engine_name="hdf5", + morpho_filters=["2branch"], +): + def setUp(self): + super().setUp() + + self.cfg = Configuration.default( + cell_types=dict( + test_cell_morpho=dict( + spatial=dict( + radius=1, density=1, morphologies=[dict(names=["2branch"])] + ) + ), + test_cell_pc_1=dict(spatial=dict(radius=1, density=1)), + test_cell_pc_2=dict(spatial=dict(radius=1, density=1)), + ), + placement=dict( + fixed_pos_morpho=dict( + strategy="bsb.placement.FixedPositions", + cell_types=["test_cell_morpho"], + partitions=[], + positions=[[0, 0, 0], [0, 0, 100], [50, 0, 0], [0, -100, 0]], + ), + fixed_pos_pc_1=dict( + strategy="bsb.placement.FixedPositions", + cell_types=["test_cell_pc_1"], + partitions=[], + positions=[[40, 40, 40]], + ), + fixed_pos_pc_2=dict( + strategy="bsb.placement.FixedPositions", + cell_types=["test_cell_pc_2"], + partitions=[], + positions=[[0, -100, 0]], + ), + ), + ) + + self.network = Scaffold(self.cfg, self.storage) + self.network.compile(skip_connectivity=True) + + def test_shape_to_shape(self): + voxel_size = 25 + config_sphere = dict(type="sphere", radius=40.0, origin=[0, 0, 0]) + ball_shape = { + "voxel_size": voxel_size, + "shapes": [config_sphere], + "labels": [["sphere"]], + } + # All the points of the presyn shape are inside the postsyn shape + self.network.connectivity["shape_to_shape_1"] = ShapeToShapeIntersection( + presynaptic=dict( + cell_types=["test_cell_pc_1"], + shapes_composition=ball_shape, + morphology_labels=["soma"], + ), + postsynaptic=dict( + cell_types=["test_cell_pc_1"], + shapes_composition=ball_shape, + morphology_labels=["soma"], + ), + affinity=0.9, + pruning_ratio=0.1, + ) + + # There are no intersections between the presyn and postsyn shapes + self.network.connectivity["shape_to_shape_2"] = ShapeToShapeIntersection( + presynaptic=dict( + cell_types=["test_cell_pc_1"], + shapes_composition=ball_shape, + morphology_labels=["soma"], + ), + postsynaptic=dict( + cell_types=["test_cell_pc_2"], + shapes_composition=ball_shape, + morphology_labels=["soma"], + ), + affinity=0.9, + pruning_ratio=0.1, + ) + + self.network.compile(skip_placement=True, append=True) + + cs = self.network.get_connectivity_set("shape_to_shape_1") + con = cs.load_connections().all()[0] + intersection_points = len(con) + self.assertGreater( + intersection_points, + 0, + "expected at least one intersection point", + ) + + with self.assertRaises(DatasetNotFoundError): + # No connectivity set expected because no overlap of the populations' chunks. + self.network.get_connectivity_set("shape_to_shape_2") + + def test_shape_to_morpho(self): + voxel_size = 25 + config_sphere = dict(type="sphere", radius=40.0, origin=[0, 0, 0]) + ball_shape = { + "voxel_size": voxel_size, + "shapes": [config_sphere], + "labels": [["sphere"]], + } + + # We know a priori that there are intersections between the presyn shape and the morphology + self.network.connectivity["shape_to_morpho_1"] = ShapeToMorphologyIntersection( + presynaptic=dict( + cell_types=["test_cell_pc_2"], + shapes_composition=ball_shape, + morphology_labels=["soma"], + ), + postsynaptic=dict(cell_types=["test_cell_morpho"]), + affinity=0.5, + pruning_ratio=0.5, + ) + + # There are no intersections between the presyn shape and the morpho + self.network.connectivity["shape_to_morpho_2"] = ShapeToMorphologyIntersection( + presynaptic=dict( + cell_types=["test_cell_pc_1"], + shapes_composition=ball_shape, + morphology_labels=["soma"], + ), + postsynaptic=dict(cell_types=["test_cell_morpho"]), + affinity=0.5, + pruning_ratio=0.5, + ) + + self.network.compile(skip_placement=True, append=True) + + cs = self.network.get_connectivity_set("shape_to_morpho_1") + con = cs.load_connections().all()[0] + intersection_points = len(con) + self.assertGreater( + intersection_points, 0, "expected at least one intersection point" + ) + + cs = self.network.get_connectivity_set("shape_to_morpho_2") + con = cs.load_connections().all()[0] + intersection_points = len(con) + self.assertClose(0, intersection_points, "expected no intersection points") + + def test_morpho_to_shape(self): + voxel_size = 25 + config_sphere = dict(type="sphere", radius=40.0, origin=[0, 0, 0]) + ball_shape = { + "voxel_size": voxel_size, + "shapes": [config_sphere], + "labels": [["sphere"]], + } + + # We know a priori that there are intersections between the presyn shape and the morphology + self.network.connectivity["shape_to_morpho_1"] = MorphologyToShapeIntersection( + postsynaptic=dict( + cell_types=["test_cell_pc_2"], + shapes_composition=ball_shape, + morphology_labels=["soma"], + ), + presynaptic=dict(cell_types=["test_cell_morpho"]), + affinity=0.5, + pruning_ratio=0.5, + ) + + # There are no intersections between the presyn shape and the morphology. + self.network.connectivity["shape_to_morpho_2"] = MorphologyToShapeIntersection( + postsynaptic=dict( + cell_types=["test_cell_pc_1"], + shapes_composition=ball_shape, + morphology_labels=["soma"], + ), + presynaptic=dict(cell_types=["test_cell_morpho"]), + affinity=0.5, + pruning_ratio=0.5, + ) + + self.network.compile(skip_placement=True, append=True) + + cs = self.network.get_connectivity_set("shape_to_morpho_1") + con = cs.load_connections().all()[0] + intersection_points = len(con) + self.assertGreater( + intersection_points, 0, "expected at least one intersection point" + ) + + cs = self.network.get_connectivity_set("shape_to_morpho_2") + con = cs.load_connections().all()[0] + intersection_points = len(con) + self.assertClose(0, intersection_points, "expected no intersection points") diff --git a/tests/test_geometric_shapes.py b/tests/test_geometric_shapes.py new file mode 100644 index 00000000..e6b5aed3 --- /dev/null +++ b/tests/test_geometric_shapes.py @@ -0,0 +1,707 @@ +import unittest + +import numpy as np +from bsb_test import NumpyTestCase + +from bsb import RequirementError +from bsb.connectivity import ( + Cone, + Cuboid, + Cylinder, + Ellipsoid, + Parallelepiped, + ShapesComposition, + Sphere, +) + + +class TestGeometricShapes(unittest.TestCase, NumpyTestCase): + def _check_points_inside(self, sc, volume, voxel_size): + self.assertClose(np.sum(sc.get_volumes()), volume) + expected_number_of_points = int(volume / voxel_size**3) + point_cloud = sc.generate_point_cloud() + npoints = len(point_cloud) + + # Check the number of points in the point cloud + self.assertEqual( + npoints, + expected_number_of_points, + "The number of point in the point cloud is not the expected one", + ) + + # Check if the point cloud is inside the sphere + points_inside_sphere = sc.inside_shapes(point_cloud) + all_points_inside = np.all(points_inside_sphere) + self.assertEqual( + all_points_inside, + True, + "The point cloud should be inside the ShapeComposition", + ) + + def _check_translation(self, sc, expected_mbb): + # Check translation + translation_vec = np.array([1.0, 10.0, 100.0]) + sc.translate(translation_vec) + mbb = sc.find_mbb() + expected_mbb += translation_vec + self.assertClose( + mbb[0], + expected_mbb[0], + "The minimal bounding box returned by find_mbb method is not the expected one", + ) + self.assertClose( + mbb[1], + expected_mbb[1], + "The minimal bounding box returned by find_mbb method is not the expected one", + ) + sc.translate(-translation_vec) + expected_mbb -= translation_vec + + # Create a sphere, add it to a ShapeComposition object and test the minimal bounding box, + # inside_mbox, inside_shapes and generate_point_cloud methods + def test_sphere(self): + # Create a ShapesComposition object; In this test the size of the voxel is not important. + conf = dict(voxel_size=25, shapes=[], labels=[]) + sc = ShapesComposition(conf) + self.assertEqual(None, sc.generate_point_cloud()) + self.assertEqual(None, sc.inside_shapes(np.array([[0.0, 0.0, 0.0]]))) + self.assertEqual(None, sc.generate_wireframe()) + self.assertEqual([], sc.get_volumes()) + + # Add the sphere to the ShapesComposition object + radius = 100.0 + origin = np.array([0, 0, 0], dtype=np.float64) + configuration = dict(radius=radius, origin=origin) + sc.add_shape(Sphere(configuration), ["sphere"]) + + # Find the mmb + mbb = sc.find_mbb() + expected_mbb = np.array( + [[-100.0, -100.0, -100.0], [100.0, 100.0, 100.0]], dtype=np.float64 + ) + + # If the result is correct the mmb is the box individuated by + # the opposite vertices [-100., -100., 0.] and [100., 100., 100.]. + # The tuple must be in the correct order. + self.assertClose( + mbb[0], + expected_mbb[0], + "The minimal bounding box returned by find_mbb method is not the expected one", + ) + self.assertClose( + mbb[1], + expected_mbb[1], + "The minimal bounding box returned by find_mbb method is not the expected one", + ) + + # The point of coordinates (0,0,50) is inside the sphere, while (200,200,200) is not. + # We test check_mbox, check_inside with these two points + point_to_check = np.array([[0, 0, 50], [200, 200, 200]], dtype=np.float64) + inside_mbox = sc.inside_mbox(point_to_check) + inside_shape = sc.inside_shapes(point_to_check) + + # Check if the points are inside the mbb + expected_inside_mbox = [True, False] + self.assertEqual( + inside_mbox[0], + expected_inside_mbox[0], + "The point (0,0,50) should be inside the minimal bounding box", + ) + self.assertEqual( + inside_mbox[1], + expected_inside_mbox[1], + "The point (200,200,200) should be outside the minimal bounding box", + ) + + # Check if the points are inside the sphere + expected_inside_shape = [True, False] + self.assertEqual( + inside_shape[0], + expected_inside_shape[0], + "The point (0,0,50) should be inside the sphere", + ) + self.assertEqual( + inside_shape[1], + expected_inside_shape[1], + "The point (0,0,-50) should be outside the sphere", + ) + + # Test generate_point_cloud method. + # The expected number of points is given by the volume of the sphere divided by the voxel + # side to the third. + # The points should be inside the sphere. + volume = 4 * (np.pi * configuration["radius"] ** 3) / 3.0 + self._check_points_inside(sc, volume, conf["voxel_size"]) + self._check_translation(sc, expected_mbb) + wireframe = np.array(sc.generate_wireframe()) + self.assertEqual((3, 1, 30, 30), wireframe.shape) + self.assertTrue( + np.allclose(np.linalg.norm(wireframe[:, 0, :, 0].T - origin, axis=1), radius) + ) + + # Create an ellipsoid, add it to a ShapeComposition object and test the minimal bounding box, + # inside_mbox, inside_shapes and generate_point_cloud methods + def test_ellipsoid(self): + # Create a ShapesComposition object; In this test the size of the voxel is not important. + conf = dict(voxel_size=25, shapes=[], labels=[]) + sc = ShapesComposition(conf) + + # Add the ellipsoid to the ShapesComposition object + configuration = dict( + origin=np.array([0, 0, 0], dtype=np.float64), + lambdas=np.array([50, 100, 10], dtype=np.float64), + v0=np.array([1, 0, 0], dtype=np.float64), + v1=np.array([0, 1, 0], dtype=np.float64), + v2=np.array([0, 0, 1], dtype=np.float64), + ) + ellipsoid = Ellipsoid(configuration) + sc.add_shape(ellipsoid, ["ellipsoid"]) + + # Find the mmb + mbb = sc.find_mbb() + expected_mbb = np.array( + [[-50.0, -100.0, -10.0], [50.0, 100.0, 10.0]], dtype=np.float64 + ) + + # If the result is correct the mmb is the box individuated by + # the opposite vertices [-100., -100., -10.] and [100., 100., 10.]. + # The tuple must be in the correct order. + self.assertClose( + mbb[0], + expected_mbb[0], + "The minimal bounding box returned by find_mbb method is not the expected one", + ) + self.assertClose( + mbb[1], + expected_mbb[1], + "The minimal bounding box returned by find_mbb method is not the expected one", + ) + + # The point of coordinates (0,0,5) is inside the ellipsoid, while (20,20,20) is not. + # We test check_mbox, check_inside with these two points + point_to_check = np.array([[0, 0, 5], [20, 20, 20]], dtype=np.float64) + inside_mbox = sc.inside_mbox(point_to_check) + inside_shape = sc.inside_shapes(point_to_check) + + # Check if the points are inside the mbb + expected_inside_mbox = [True, False] + self.assertEqual( + inside_mbox[0], + expected_inside_mbox[0], + "The point (0,0,50) should be inside the minimal bounding box", + ) + self.assertEqual( + inside_mbox[1], + expected_inside_mbox[1], + "The point (200,200,200) should be outside the minimal bounding box", + ) + + # Check if the points are inside the ellipsoid + expected_inside_shape = [True, False] + self.assertEqual( + inside_shape[0], + expected_inside_shape[0], + "The point (0,0,50) should be inside the ellipsoid", + ) + self.assertEqual( + inside_shape[1], + expected_inside_shape[1], + "The point (0,0,-50) should be outside the ellipsoid", + ) + + # Test generate_point_cloud method. + # The expected number of points is given by the volume of the ellipsoid divided by the voxel side to the third + # The points should be inside the ellipsoid. + volume = ( + np.pi + * configuration["lambdas"][0] + * configuration["lambdas"][1] + * configuration["lambdas"][2] + ) + self._check_points_inside(sc, volume, conf["voxel_size"]) + self._check_translation(sc, expected_mbb) + + # Check rotation + ellipsoid.rotate(np.array([0.0, 0.0, 1.0]), np.pi / 2) + mbb = ellipsoid.find_mbb() + self.assertClose(mbb[0], expected_mbb[0, [1, 0, 2]]) + self.assertClose(mbb[1], expected_mbb[1, [1, 0, 2]]) + + wireframe = np.array(sc.generate_wireframe()) + self.assertEqual((3, 1, 30, 30), wireframe.shape) + for coord in (wireframe[:, 0, :, 0] - ellipsoid.origin[..., np.newaxis]).T: + self.assertTrue(-1e-5 <= coord[0] <= 1e5) + self.assertTrue(-1e-5 <= coord[1] <= 50) + self.assertTrue(-10 <= coord[2] <= 10) + + # Create a cylinder, add it to a ShapeComposition object and test the minimal bounding box, + # inside_mbox, inside_shapes and generate_point_cloud methods + def test_cylinder(self): + # Create a ShapesComposition object; In this test the size of the voxel is not important. + conf = dict(voxel_size=25, shapes=[], labels=[]) + sc = ShapesComposition(conf) + + # Add the cylinder to the ShapesComposition object + radius = 100.0 + origin = np.array([0, 0, 0], dtype=np.float64) + top_center = np.array([0, 0, 10], dtype=np.float64) + + configuration = dict( + radius=radius, + origin=origin, + top_center=top_center, + ) + cylinder = Cylinder(configuration) + sc.add_shape(cylinder, ["cylinder"]) + + # Find the mmb + mbb = sc.find_mbb() + expected_mbb = np.array( + [[-100.0, -100.0, 0.0], [100.0, 100.0, 10.0]], dtype=np.float64 + ) + + # If the result is correct the mmb is the box individuated by + # the opposite vertices [-100., -100., 0.] and [100., 100., 100.]. + # The tuple must be in the correct order. + self.assertClose( + mbb[0], + expected_mbb[0], + "The minimal bounding box returned by find_mbb method is not the expected one", + ) + self.assertClose( + mbb[1], + expected_mbb[1], + "The minimal bounding box returned by find_mbb method is not the expected one", + ) + + # The point of coordinates (0,0,5) is inside the cylinder, while (200,200,200) is not. + # We test check_mbox, check_inside with these two points + point_to_check = np.array([[0, 0, 5], [200, 200, 200]], dtype=np.float64) + inside_mbox = sc.inside_mbox(point_to_check) + inside_shape = sc.inside_shapes(point_to_check) + + # Check if the points are inside the mbb + expected_inside_mbox = [True, False] + self.assertEqual( + inside_mbox[0], + expected_inside_mbox[0], + "The point (0,0,50) should be inside the minimal bounding box", + ) + self.assertEqual( + inside_mbox[1], + expected_inside_mbox[1], + "The point (200,200,200) should be outside the minimal bounding box", + ) + + # Check if the points are inside the cylinder + expected_inside_shape = [True, False] + self.assertEqual( + inside_shape[0], + expected_inside_shape[0], + "The point (0,0,50) should be inside the cylinder", + ) + self.assertEqual( + inside_shape[1], + expected_inside_shape[1], + "The point (0,0,-50) should be outside the cylinder", + ) + + # Test generate_point_cloud method. + # The expected number of points is given by the volume of the cylinder divided by the voxel side to the third + # The points should be inside the cylinder. + height = np.linalg.norm(configuration["top_center"] - configuration["origin"]) + volume = np.pi * height * configuration["radius"] ** 2 + self._check_points_inside(sc, volume, conf["voxel_size"]) + self._check_translation(sc, expected_mbb) + + # Check rotation + cylinder.rotate(np.array([1.0, 0.0, 0.0]), np.pi / 2) + mbb = cylinder.find_mbb() + self.assertClose(mbb[0], [-100.0, -10.0, -100.0]) + self.assertClose(mbb[1], [100.0, 0.0, 100.0]) + + wireframe = np.array(sc.generate_wireframe()) + self.assertEqual((3, 1, 30, 30), wireframe.shape) + self.assertTrue( + np.allclose( + np.linalg.norm( + (wireframe[:, 0, :, 0] - cylinder.origin[..., np.newaxis])[ + np.array([0, 2]) + ], + axis=0, + ), + radius, + ) + ) + for p in wireframe[:, 0, :, 0].T: + self.assertTrue( + 0 + <= np.absolute(p[1] - cylinder.origin[1]) + <= np.absolute(cylinder.top_center[1] - cylinder.origin[1]) + ) + + # Create a parallelepiped, add it to a ShapeComposition object and test the minimal bounding + # box, inside_mbox, inside_shapes and generate_point_cloud methods + def test_parallelepiped(self): + # Create a ShapesComposition object; In this test the size of the voxel is not important. + conf = dict(voxel_size=5, shapes=[], labels=[]) + sc = ShapesComposition(conf) + + # Add the parallelepiped to the ShapesComposition object + configuration = dict( + origin=np.array([-5, -5, -5], dtype=np.float64), + side_vector_1=np.array([10, 0, 0], dtype=np.float64), + side_vector_2=np.array([0, 100, 0], dtype=np.float64), + side_vector_3=np.array([0, 0, 10], dtype=np.float64), + ) + parallelepiped = Parallelepiped(configuration) + sc.add_shape( + parallelepiped, + ["parallelepiped"], + ) + + # Find the mmb + mbb = sc.find_mbb() + expected_mbb = np.array([[-5.0, -5.0, -5.0], [5.0, 95.0, 5.0]], dtype=np.float64) + + # If the result is correct the mmb is the box individuated by + # the opposite vertices [-5., -5., -5.] and [5., 5., 5.]. + # The tuple must be in the correct order. + self.assertClose( + mbb[0], + expected_mbb[0], + "The minimal bounding box returned by find_mbb method is not the expected one", + ) + self.assertClose( + mbb[1], + expected_mbb[1], + "The minimal bounding box returned by find_mbb method is not the expected one", + ) + + # The point of coordinates (0,0,0) is inside the parallelepiped, while (10,10,10) is not. + # We test check_mbox, check_inside with these two points + point_to_check = np.array([[0, 0, 0], [10, 100, 10]], dtype=np.float64) + inside_mbox = sc.inside_mbox(point_to_check) + inside_shape = sc.inside_shapes(point_to_check) + + # Check if the points are inside the mbb + expected_inside_mbox = [True, False] + self.assertEqual( + inside_mbox[0], + expected_inside_mbox[0], + "The point (0,0,0) should be inside the minimal bounding box", + ) + self.assertEqual( + inside_mbox[1], + expected_inside_mbox[1], + "The point (10,10,10) should be outside the minimal bounding box", + ) + + # Check if the points are inside the parallelepiped + expected_inside_shape = [True, False] + self.assertEqual( + inside_shape[0], + expected_inside_shape[0], + "The point (0,0,0) should be inside the parallelepiped", + ) + self.assertEqual( + inside_shape[1], + expected_inside_shape[1], + "The point (10,100,10) should be outside the parallelepiped", + ) + + # Test generate_point_cloud method. + # The expected number of points is given by the volume of the parallelepiped divided by the + # voxel side to the third. + # The points should be inside the parallelepiped. + volume = ( + np.linalg.norm(configuration["side_vector_1"]) + * np.linalg.norm(configuration["side_vector_2"]) + * np.linalg.norm(configuration["side_vector_3"]) + ) + self._check_points_inside(sc, volume, conf["voxel_size"]) + self._check_translation(sc, expected_mbb) + + # Check rotation + parallelepiped.rotate(np.array([1.0, 0.0, 0.0]), np.pi / 2) + mbb = parallelepiped.find_mbb() + expected_mbb = np.array( + [[-5.0, -15.0, -5.0], [5.0, -5.0, 95.0]], dtype=np.float64 + ) + self.assertClose(mbb[0], expected_mbb[0]) + self.assertClose(mbb[1], expected_mbb[1]) + wireframe = np.array(sc.generate_wireframe()) + self.assertEqual((3, 1, 4, 4), wireframe.shape) + self.assertTrue(np.alltrue(wireframe.reshape(3, 16).T - expected_mbb[0] >= -1e-5)) + self.assertTrue(np.alltrue(wireframe.reshape(3, 16).T - expected_mbb[1] <= 1e-5)) + + # Create a cuboid, add it to a ShapeComposition object and test the minimal bounding box, + # inside_mbox, inside_shapes and generate_point_cloud methods + def test_cuboid(self): + # Create a ShapesComposition object; In this test the size of the voxel is not important. + conf = dict(voxel_size=25, shapes=[], labels=[]) + sc = ShapesComposition(conf) + + # Add the cuboid to the ShapesComposition object + configuration = dict( + origin=np.array([0, 0, 0], dtype=np.float64), + side_length_1=5.0, + side_length_2=10.0, + top_center=np.array([0, 0, 20], dtype=np.float64), + ) + cuboid = Cuboid(configuration) + sc.add_shape(cuboid, ["cuboid"]) + + # Find the mmb + mbb = sc.find_mbb() + expected_mbb = np.array([[-2.5, -5.0, 0.0], [2.5, 5.0, 20.0]], dtype=np.float64) + + # If the result is correct the mmb is the box individuated by + # the opposite vertices [-5., -5., -5.] and [5., 5., 5.]. + # The tuple must be in the correct order. + self.assertClose( + mbb[0], + expected_mbb[0], + "The minimal bounding box returned by find_mbb method is not the expected one", + ) + self.assertClose( + mbb[1], + expected_mbb[1], + "The minimal bounding box returned by find_mbb method is not the expected one", + ) + + # The point of coordinates (1,1,1) is inside the cuboid, while (10,10,10) is not. + # We test check_mbox, check_inside with these two points + point_to_check = np.array([[1, 1, 1], [10, 10, 10]], dtype=np.float64) + inside_mbox = sc.inside_mbox(point_to_check) + inside_shape = sc.inside_shapes(point_to_check) + + # Check if the points are inside the mbb + expected_inside_mbox = [True, False] + self.assertEqual( + inside_mbox[0], + expected_inside_mbox[0], + "The point (0,0,0) should be inside the minimal bounding box", + ) + self.assertEqual( + inside_mbox[1], + expected_inside_mbox[1], + "The point (10,10,10) should be outside the minimal bounding box", + ) + + # Check if the points are inside the cuboid + expected_inside_shape = [True, False] + self.assertEqual( + inside_shape[0], + expected_inside_shape[0], + "The point (0,0,0) should be inside the cuboid", + ) + self.assertEqual( + inside_shape[1], + expected_inside_shape[1], + "The point (10,10,10) should be outside the cuboid", + ) + + # Test generate_point_cloud method. + # The expected number of points is given by the volume of the cuboid divided by the voxel side to the third + # The points should be inside the cuboid. + volume = ( + configuration["side_length_1"] + * configuration["side_length_2"] + * np.linalg.norm(configuration["top_center"]) + ) + self._check_points_inside(sc, volume, conf["voxel_size"]) + self._check_translation(sc, expected_mbb) + + # Check rotation + expected_mbb = np.array([[-2.5, -5.0, 0.0], [2.5, 5.0, 20.0]], dtype=np.float64) + cuboid.rotate(np.array([0.0, 0.0, 1.0]), np.pi / 2) + mbb = cuboid.find_mbb() + self.assertClose(mbb[0], expected_mbb[0]) + self.assertClose(mbb[1], expected_mbb[1]) + wireframe = np.array(sc.generate_wireframe()) + self.assertEqual((3, 1, 4, 4), wireframe.shape) + self.assertTrue(np.alltrue(wireframe.reshape(3, 16).T - expected_mbb[0] >= -1e-5)) + self.assertTrue(np.alltrue(wireframe.reshape(3, 16).T - expected_mbb[1] <= 1e-5)) + + # Create a cone, add it to a ShapeComposition object and test the minimal bounding box, + # inside_mbox, inside_shapes and generate_point_cloud methods + def test_cone(self): + # Create a ShapesComposition object; In this test the size of the voxel is not important. + conf = dict(voxel_size=50, shapes=[], labels=[]) + sc = ShapesComposition(conf) + + # Add the cone to the ShapesComposition object + radius = 100.0 + origin = np.array([0, 0, 100], dtype=np.float64) + apex = np.array([0, 0, 0], dtype=np.float64) + configuration = { + "origin": origin, + "radius": radius, + "apex": apex, + } + cone = Cone(configuration) + sc.add_shape(cone, ["cone"]) + + # Find the mmb + mbb = sc.find_mbb() + expected_mbb = np.array( + [[-100.0, -100.0, 0.0], [100.0, 100.0, 100.0]], dtype=np.float64 + ) + + # If the result is correct the mmb is the box individuated by + # the opposite vertices [-100., -100., 0.] and [100., 100., 100.]. + # The tuple must be in the correct order. + self.assertClose( + mbb[0], + expected_mbb[0], + "The minimal bounding box returned by find_mbb method is not the expected one", + ) + self.assertClose( + mbb[1], + expected_mbb[1], + "The minimal bounding box returned by find_mbb method is not the expected one", + ) + + # The point of coordinates (0,0,50) is inside the cone, while (0,0,-50) is not. + # We test check_mbox, check_inside with these two points + point_to_check = np.array([[0, 0, 50], [0, 0, -50]], dtype=np.float64) + inside_mbox = sc.inside_mbox(point_to_check) + inside_shape = sc.inside_shapes(point_to_check) + + # Check if the points are inside the mbb + expected_inside_mbox = [True, False] + self.assertEqual( + inside_mbox[0], + expected_inside_mbox[0], + "The point (0,0,50) should be inside the minimal bounding box", + ) + self.assertEqual( + inside_mbox[1], + expected_inside_mbox[1], + "The point (0,0,-50) should be outside the minimal bounding box", + ) + + # Check if the points are inside the cone + expected_inside_shape = [True, False] + self.assertEqual( + inside_shape[0], + expected_inside_shape[0], + "The point (0,0,50) should be inside the cone", + ) + self.assertEqual( + inside_shape[1], + expected_inside_shape[1], + "The point (0,0,-50) should be outside the cone", + ) + + # Test generate_point_cloud method. + # The expected number of points is given by the volume of the cone divided by the voxel side + # to the third. + # The points should be inside the cone. + cone_height = np.linalg.norm(configuration["origin"] - configuration["apex"]) + volume = (np.pi * cone_height * configuration["radius"] ** 2) / 3.0 + self._check_points_inside(sc, volume, conf["voxel_size"]) + self._check_translation(sc, expected_mbb) + + # Check rotation + expected_mbb = np.array( + [[-100.0, -100.0, 0.0], [100.0, 100.0, 100.0]], dtype=np.float64 + ) + cone.rotate(np.array([0.0, 1.0, 0.0]), np.pi / 2) + mbb = cone.find_mbb() + self.assertClose(mbb[0], expected_mbb[0]) + self.assertClose(mbb[1], expected_mbb[1]) + + wireframe = np.array(sc.generate_wireframe()) + self.assertEqual((3, 1, 30, 30), wireframe.shape) + self.assertTrue( + np.alltrue( + np.linalg.norm( + (wireframe[:, 0, :, 0] - cone.origin[..., np.newaxis])[ + np.array([0, 1]) + ], + axis=0, + ) + - radius + <= 1e-5 + ) + ) + for p in wireframe[:, 0, :, 0].T: + self.assertTrue( + 0 + <= np.absolute(p[2] - cone.origin[2]) + <= np.absolute(cone.apex[2] - cone.origin[2]) + ) + + # Create ShapeComposition object, add a sphere and a cylinder and then test + def test_shape_composition(self): + config_sphere = dict(radius=10.0, origin=np.array([0, 0, 0], dtype=np.float64)) + config_cylinder = dict( + top_center=np.array([0, 0, 0], dtype=np.float64), + radius=25.0, + origin=np.array([0, 0, -40], dtype=np.float64), + ) + + with self.assertRaises(RequirementError): + ShapesComposition( + dict( + shapes=[ + dict( + type="sphere", + radius=10.0, + center=np.array([0, 0, 0], dtype=np.float64), + ) + ], + labels=[["label1"], ["label2"]], + ) + ) + + with self.assertRaises(RequirementError): + ShapesComposition(dict(shapes=[], labels=[["label1"], ["label2"]])) + + # Create a ShapesComposition object + conf = dict(voxel_size=10, shapes=[], labels=[]) + sc = ShapesComposition(conf) + + # Build a sphere + sc.add_shape(Sphere(config_sphere), ["sphere"]) + + # Build a cylinder + sc.add_shape(Cylinder(config_cylinder), ["cylinder"]) + + # Check if shapes filtering by labels works + filtered_shape = sc.filter_by_labels(["sphere"]) + self.assertIsInstance(filtered_shape._shapes[0], Sphere) + + # Test the mininimal bounding box of the composition + # The expected mbb is [-25,-25,-40], [25,25,10] + mbb = sc.find_mbb() + expected_mbb = ( + np.array([-25.0, -25.0, -40.0], dtype=np.float64), + np.array([25.0, 25.0, 10.0], dtype=np.float64), + ) + + self.assertClose( + mbb[0], + expected_mbb[0], + "The minimal bounding box returned by find_mbb method is not the expected one", + ) + self.assertClose( + mbb[1], + expected_mbb[1], + "The minimal bounding box returned by find_mbb method is not the expected one", + ) + + # Test generate_point_cloud method for the composition of shapes. + # The expected number of points is given by the sum of the volumes the voxel side to the third + # The points should be inside the cone. + sphere_volume = 4.0 / 3.0 * np.pi * config_sphere["radius"] ** 3 + cylinder_volume = ( + np.pi + * np.linalg.norm(config_cylinder["origin"]) + * config_cylinder["radius"] ** 2 + ) + total_volume = sphere_volume + cylinder_volume + self._check_points_inside(sc, total_volume, conf["voxel_size"]) + self._check_translation(sc, expected_mbb) diff --git a/tests/test_morphologies.py b/tests/test_morphologies.py index 65a1aede..ea8fd223 100644 --- a/tests/test_morphologies.py +++ b/tests/test_morphologies.py @@ -384,6 +384,20 @@ def test_delete_point(self): self.assertClose(branch.radii, np.array([1, 2])) self.assertClose([0, 1], branch.labels) + def test_swap_axes(self): + points = np.arange(9).reshape(3, 3) + morpho = Morphology([Branch(points, np.array([0, 1, 2]))]) + morpho.swap_axes(0, 0) # no swap + self.assertAll(morpho.points == points) + morpho.swap_axes(0, 1) + self.assertAll( + morpho.points == np.vstack([points[:, 1], points[:, 0], points[:, 2]]).T + ) + with self.assertRaises(ValueError): + morpho.swap_axes(3, 1) + with self.assertRaises(ValueError): + morpho.swap_axes(1, -1) + class TestMorphologyLabels(NumpyTestCase, unittest.TestCase): def test_labels(self): diff --git a/tests/test_parsers.py b/tests/test_parsers.py new file mode 100644 index 00000000..2caf0335 --- /dev/null +++ b/tests/test_parsers.py @@ -0,0 +1,255 @@ +import ast +import pathlib +import unittest +from unittest.mock import patch + +from bsb.config.parsers import ( + ConfigurationParser, + ParsesReferences, + get_configuration_parser, +) +from bsb.exceptions import ConfigurationWarning, FileReferenceError, PluginError + + +def get_content(file: str): + return (pathlib.Path(__file__).parent / "data/configs" / file).read_text() + + +class RefParserMock(ParsesReferences, ConfigurationParser): + data_description = "txt" + data_extensions = ("txt",) + + def parse(self, content, path=None): + if isinstance(content, str): + content = ast.literal_eval(content) + return content, {"meta": path} + + def generate(self, tree, pretty=False): + # Should not be called. + pass + + +class RefParserMock2(ParsesReferences, ConfigurationParser): + data_description = "bla" + data_extensions = ("bla",) + + def parse(self, content, path=None): + if isinstance(content, str): + content = content.replace("<", "{") + content = content.replace(">", "}") + content = ast.literal_eval(content) + return content, {"meta": path} + + def generate(self, tree, pretty=False): + # Should not be called. + pass + + +class TestParsersBasics(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.parser = RefParserMock() + + def test_get_parser(self): + self.assertRaises(PluginError, get_configuration_parser, "doesntexist") + + def test_parse_empty_doc(self): + tree, meta = self.parser.parse({}) + self.assertEqual({}, tree, "'{}' parse should produce empty dict") + + def assert_basics(self, tree, meta): + self.assertEqual(3, tree["list"][2], "Incorrectly parsed basic Txt") + self.assertEqual( + "just like that", + tree["nest me hard"]["oh yea"], + "Incorrectly parsed nested File", + ) + self.assertEqual( + "", str(tree["list"]) + ) + + def test_parse_basics(self): + # test from str + self.assert_basics(*self.parser.parse(get_content("basics.txt"))) + + # test from dict + content = { + "hello": "world", + "list": [1, 2, 3, "waddup"], + "nest me hard": {"oh yea": "just like that"}, + } + self.assert_basics(*self.parser.parse(content)) + + +class TestFileRef(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.parser = RefParserMock() + + def test_indoc_reference(self): + content = ast.literal_eval(get_content("indoc_reference.txt")) + tree, meta = self.parser.parse(content) + self.assertNotIn("$ref", tree["refs"]["whats the"], "Ref key not removed") + self.assertEqual("key", tree["refs"]["whats the"]["secret"]) + self.assertEqual("is hard", tree["refs"]["whats the"]["nested secrets"]["vim"]) + self.assertEqual("convoluted", tree["refs"]["whats the"]["nested secrets"]["and"]) + # Checking str keys order. + self.assertEqual( + str(tree["refs"]["whats the"]["nested secrets"]), + "", + ) + self.assertEqual(tree["refs"]["whats the"], tree["refs"]["omitted_doc"]) + content["get"]["a"] = "secret" + with self.assertRaises(FileReferenceError, msg="Should raise 'ref not a dict'"): + tree, meta = self.parser.parse(content) + + @patch("bsb.config.parsers.get_configuration_parser_classes") + def test_far_references(self, get_content_mock): + # Override get_configuration_parser to manually register RefParserMock + get_content_mock.return_value = {"txt": RefParserMock, "bla": RefParserMock2} + content = { + "refs": { + "whats the": {"$ref": "basics.txt#/nest me hard"}, + "and": {"$ref": "indoc_reference.txt#/refs/whats the"}, + "far": {"$ref": "far/targetme.bla#/this/key"}, + }, + "target": {"for": "another"}, + } + tree, meta = self.parser.parse( + content, + path=str( + (pathlib.Path(__file__).parent / "data" / "configs" / "interdoc_refs.txt") + ), + ) + self.assertIn("was", tree["refs"]["far"]) + self.assertEqual("in another folder", tree["refs"]["far"]["was"]) + self.assertIn("oh yea", tree["refs"]["whats the"]) + self.assertEqual("just like that", tree["refs"]["whats the"]["oh yea"]) + + @patch("bsb.config.parsers.get_configuration_parser_classes") + def test_double_ref(self, get_content_mock): + # Override get_configuration_parser to manually register RefParserMock + get_content_mock.return_value = {"txt": RefParserMock, "bla": RefParserMock2} + tree, meta = self.parser.parse( + get_content("doubleref.txt"), + path=str( + (pathlib.Path(__file__).parent / "data" / "configs" / "doubleref.txt") + ), + ) + # Only the latest ref is included because the literal_eval keeps only the latest value + # for similar keys + self.assertNotIn("oh yea", tree["refs"]["whats the"]) + self.assertIn("for", tree["refs"]["whats the"]) + self.assertIn("another", tree["refs"]["whats the"]["for"]) + + @patch("bsb.config.parsers.get_configuration_parser_classes") + def test_ref_str(self, get_content_mock): + # Override get_configuration_parser to manually register RefParserMock + get_content_mock.return_value = {"txt": RefParserMock, "bla": RefParserMock2} + tree, meta = self.parser.parse( + get_content("doubleref.txt"), + path=str( + (pathlib.Path(__file__).parent / "data" / "configs" / "doubleref.txt") + ), + ) + self.assertTrue(str(self.parser.references[0]).startswith("") + ) + + @patch("bsb.config.parsers.get_configuration_parser_classes") + def test_wrong_ref(self, get_content_mock): + # Override get_configuration_parser to manually register RefParserMock + get_content_mock.return_value = {"txt": RefParserMock, "bla": RefParserMock2} + + content = {"refs": {"whats the": {"$ref": "basics.txt#/oooooooooooooo"}}} + with self.assertRaises(FileReferenceError, msg="ref should not exist"): + self.parser.parse( + content, + path=str( + (pathlib.Path(__file__).parent / "data" / "configs" / "wrong_ref.txt") + ), + ) + + +class TestFileImport(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.parser = RefParserMock() + + def test_indoc_import(self): + tree, meta = self.parser.parse(get_content("indoc_import.txt")) + self.assertEqual(["with", "importable"], list(tree["imp"].keys())) + self.assertEqual("are", tree["imp"]["importable"]["dicts"]["that"]) + + def test_indoc_import_list(self): + from bsb.config._parse_types import parsed_list + + content = ast.literal_eval(get_content("indoc_import.txt")) + content["arr"]["with"] = ["a", "b", ["a", "c"]] + tree, meta = self.parser.parse(content) + self.assertEqual(["with", "importable"], list(tree["imp"].keys())) + self.assertEqual("a", tree["imp"]["with"][0]) + self.assertEqual(parsed_list, type(tree["imp"]["with"][2]), "message") + + def test_indoc_import_value(self): + content = ast.literal_eval(get_content("indoc_import.txt")) + content["arr"]["with"] = "a" + tree, meta = self.parser.parse(content) + self.assertEqual(["with", "importable"], list(tree["imp"].keys())) + self.assertEqual("a", tree["imp"]["with"]) + + def test_import_merge(self): + tree, meta = self.parser.parse(get_content("indoc_import_merge.txt")) + self.assertEqual(2, len(tree["imp"].keys())) + self.assertIn("importable", tree["imp"]) + self.assertIn("with", tree["imp"]) + self.assertEqual( + ["importable", "with"], + list(tree["imp"].keys()), + "Imported keys should follow on original keys", + ) + self.assertEqual(4, tree["imp"]["importable"]["dicts"]["that"]) + self.assertEqual("eh", tree["imp"]["importable"]["dicts"]["even"]["nested"]) + self.assertEqual(["new", "list"], tree["imp"]["importable"]["dicts"]["with"]) + + def test_import_overwrite(self): + content = ast.literal_eval(get_content("indoc_import.txt")) + content["imp"]["importable"] = 10 + + with self.assertWarns(ConfigurationWarning) as warning: + tree, meta = self.parser.parse(content) + self.assertEqual(2, len(tree["imp"].keys())) + self.assertIn("importable", tree["imp"]) + self.assertIn("with", tree["imp"]) + self.assertEqual( + ["importable", "with"], + list(tree["imp"].keys()), + "Imported keys should follow on original keys", + ) + self.assertEqual(10, tree["imp"]["importable"]) + + @patch("bsb.config.parsers.get_configuration_parser_classes") + def test_outdoc_import_merge(self, get_content_mock): + # Override get_configuration_parser to manually register RefParserMock + get_content_mock.return_value = {"txt": RefParserMock, "bla": RefParserMock2} + + file = "outdoc_import_merge.txt" + tree, meta = self.parser.parse( + get_content(file), path=pathlib.Path(__file__).parent / "data/configs" / file + ) + + expected = { + "with": {}, + "importable": { + "dicts": { + "that": "are", + "even": {"nested": "eh"}, + "with": ["new", "list"], + }, + "diff": "added", + }, + } + self.assertTrue(expected == tree["imp"]) diff --git a/tests/test_placement.py b/tests/test_placement.py index 4b85e6e6..7c29c296 100644 --- a/tests/test_placement.py +++ b/tests/test_placement.py @@ -331,11 +331,12 @@ def test_particle_vd(self): voxels=network.partitions.test_part.vs, ) self.assertEqual(4, len(counts), "should have vector of counts per voxel") - self.assertTrue(np.allclose([78, 16, 8, 27], counts, atol=1), "densities incorr") + # test rounded down values + self.assertTrue(np.allclose([78, 15, 7, 26], counts, atol=1), "densities incorr") network.compile(clear=True) ps = network.get_placement_set("test_cell") - self.assertGreater(len(ps), 90) - self.assertLess(len(ps), 130) + self.assertGreater(len(ps), 125) # rounded down values -1 + self.assertLess(len(ps), 132) # rounded up values + 1 def _config_packing_fact(self): return Configuration.default( diff --git a/tests/test_util.py b/tests/test_util.py index 7dae6430..bdd7f995 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -10,7 +10,7 @@ from scipy.spatial.transform import Rotation from bsb import FileDependency, NeuroMorphoScheme, Scaffold -from bsb._util import rotation_matrix_from_vectors +from bsb._util import assert_samelen, rotation_matrix_from_vectors class TestNetworkUtil( @@ -93,3 +93,12 @@ def test_nm_scheme_down(self): file.get_meta() finally: NeuroMorphoScheme._nm_url = url + + +class TestAssertSameLength(unittest.TestCase): + def test_same_length(self): + assert_samelen([1, 2, 3], [4, 5, 6]) + with self.assertRaises(AssertionError): + assert_samelen([1, 2], [2]) + assert_samelen([[1, 2]], [3]) + assert_samelen([], [])