diff --git a/kwave/utils/kwave_array.py b/kwave/utils/kwave_array.py index 2a6ae7ab..5c7b1680 100644 --- a/kwave/utils/kwave_array.py +++ b/kwave/utils/kwave_array.py @@ -7,6 +7,7 @@ import numpy as np from numpy import arcsin, pi, cos, size, array from numpy.linalg import linalg +import pytest from kwave.data import Vector from kwave.kgrid import kWaveGrid @@ -17,6 +18,176 @@ from kwave.utils.matlab import matlab_assign, matlab_mask, matlab_find +class Group: + def __init__(self, id, name=None): + self.id = id + # self.name = name + self.elements = [] + + def add_element(self, element): + self.elements.append(element) + element.group = self + + @property + def num_elements(self): + return len(self.elements) + + def __len__(self): + return self.num_elements + + def __getitem__(self, index): + return self.elements[index] + + +@dataclass +class BaseElement: + group_id: int = 0 + dim: int = 0 + active: bool = True + element_number: int = None + + def get_integration_points(self, m_integration): + raise NotImplementedError("Subclasses must implement this method") + + @property + def measure(self): + raise NotImplementedError("Subclasses must implement this method") + + +class AnnulusElement(BaseElement): + def __init__(self, position, radius_of_curvature, inner_diameter, outer_diameter, focus_position, **kwargs): + super().__init__(**kwargs) + self.dim = 2 + self.position = np.array(position) + self.radius_of_curvature = radius_of_curvature + self.inner_diameter = inner_diameter + self.outer_diameter = outer_diameter + self.focus_position = np.array(focus_position) + + @property + def measure(self): + varphi_min = arcsin(self.inner_diameter / (2 * self.radius_of_curvature)) + varphi_max = arcsin(self.outer_diameter / (2 * self.radius_of_curvature)) + + return 2 * pi * self.radius_of_curvature**2 * (1 - cos(varphi_max)) - 2 * pi * self.radius_of_curvature**2 * (1 - cos(varphi_min)) + + def get_integration_points(self, m_integration): + return make_cart_spherical_segment( + self.position, self.radius_of_curvature, self.inner_diameter, self.outer_diameter, self.focus_position, m_integration + ) + + +class BowlElement(BaseElement): + def __init__(self, position, radius_of_curvature, diameter, focus_position, **kwargs): + super().__init__(**kwargs) + self.dim = 2 # TODO: check if this should be 3 + self.position = np.array(position) + self.radius_of_curvature = radius_of_curvature + self.diameter = diameter + self.focus_position = np.array(focus_position) + + def get_integration_points(self, m_integration): + return make_cart_bowl(self.position, self.radius_of_curvature, self.diameter, self.focus_position, m_integration) + + @property + def measure(self): + varphi_max = arcsin(self.diameter / (2 * self.radius_of_curvature)) + return 2 * pi * self.radius_of_curvature**2 * (1 - cos(varphi_max)) + + +class RectElement(BaseElement): + def __init__(self, position, length, width, theta, **kwargs): + super().__init__(**kwargs) + self.dim = 2 if len(position) == 2 else 3 + self.position = np.array(position) + self.length = length + self.width = width + self.orientation = np.array(theta) if len(position) == 3 else theta + + def get_integration_points(self, m_integration): + return make_cart_rect(self.position, self.length, self.width, self.orientation, m_integration) + + @property + def measure(self): + return self.length * self.width + + +class ArcElement(BaseElement): + def __init__(self, position, radius_of_curvature, diameter, focus_position, **kwargs): + super().__init__(**kwargs) + self.position = np.array(position) + self.dim = 1 + self.radius_of_curvature = radius_of_curvature + self.diameter = diameter + self.focus_position = np.array(focus_position) + + def get_integration_points(self, m_integration): + return make_cart_arc(self.position, self.radius_of_curvature, self.diameter, self.focus_position, m_integration) + + @property + def measure(self): + varphi_max = arcsin(self.diameter / (2 * self.radius_of_curvature)) + return 2 * self.radius_of_curvature * varphi_max + + +class DiscElement(BaseElement): + def __init__(self, position, diameter, focus_position=None, **kwargs): + super().__init__(**kwargs) + if (dim := len(position)) not in [2, 3]: + raise ValueError(f"Input position for disc element must be specified as a 2 (2D) or 3 (3D) element array. Got {dim}") + else: + self.dim = dim + self.position = np.array(position) + self.diameter = diameter + if self.dim == 3 and focus_position is None: + raise ValueError("Focus position must be provided for 3D disc element.") + + self.focus_position = np.array(focus_position) + + def get_integration_points(self, m_integration): + return make_cart_disc(self.position, self.diameter, self.focus_position, m_integration) + + @property + def measure(self): + return pi * (self.diameter / 2) ** 2 + + +class LineElement(BaseElement): + def __init__(self, start_point, end_point, **kwargs): + super().__init__(**kwargs) + self.start_point = np.array(start_point) + self.end_point = np.array(end_point) + self.dim = 1 + + def get_integration_points(self, m_integration): + raise NotImplementedError("Integration points for line elements are not yet implemented.") + + @property + def measure(self): + return linalg.norm(self.end_point - self.start_point) + + +class CustomElement(BaseElement): + def __init__(self, integration_points, measure, dim, label, **kwargs): + super().__init__(**kwargs) + self.integration_points = integration_points + self.dim = dim + self.label = label + self._user_generated_measure = measure + + def get_integration_points(self): + return self.integration_points + + @property + def measure(self): + return self._user_generated_measure + + +@dataclass +class ElementArray: + group_type: str + + @dataclass class Element: group_id: int @@ -112,6 +283,83 @@ def __init__( self.num_arc_plot_points = 100 self.element_plot_colour = np.array([0, 158, 194], dtype=float) / 255 + def add_element(self, **kwargs): + element_type = self._infer_element_type(**kwargs) + element = self._create_element(element_type, **kwargs) + element = self._convert_to_legacy_element(element) + self.elements.append(element) + self.number_elements += 1 + return element + + def _infer_legacy_element_type(self, element: BaseElement) -> str: + accepted_types = {"annulus", "bowl", "rect", "arc", "disc", "line", "custom"} + suffix = "Element" + + # Retrieve the class name of the element and remove the suffix "Element" + element_type = element.__class__.__name__ + if element_type.endswith(suffix): + element_name = element_type[: -len(suffix)].lower() + else: + raise ValueError("Element class name does not follow expected format.") + + # Validate that the derived name is an accepted type + if element_name not in accepted_types: + raise ValueError(f"Unknown element type: {element_name}") + + return element_name + + def _convert_to_legacy_element(self, element: BaseElement) -> Element: + """ + Convert a new-style element to the legacy format for compatibility. + """ + element_dict = element.__dict__ + legacy_type = self._infer_legacy_element_type(element) + element_dict["type"] = legacy_type + element_dict["measure"] = element.measure + if legacy_type == "custom": + element_dict.pop("_user_generated_measure") + + legacy_element = Element(**element_dict) + return legacy_element + + def add_array(self, **kwargs): + # element_type = self._infer_element_type(**kwargs) + pass + + def _create_element(self, element_type, **kwargs): + if element_type == "annulus": + return AnnulusElement(**kwargs) + elif element_type == "bowl": + return BowlElement(**kwargs) + elif element_type == "rect": + return RectElement(**kwargs) + elif element_type == "arc": + return ArcElement(**kwargs) + elif element_type == "disc": + return DiscElement(**kwargs) + elif element_type == "line": + return LineElement(**kwargs) + elif element_type == "custom": + return CustomElement(**kwargs) + else: + raise ValueError(f"Unknown element type: {element_type}") + + def _infer_element_type(self, **kwargs): + if "inner_diameter" in kwargs and "outer_diameter" in kwargs and "focus_position" in kwargs: + return "annulus" + elif "radius_of_curvature" in kwargs and "diameter" in kwargs: + return "arc" if len(kwargs.get("position")) == 2 else "bowl" + elif "length" in kwargs and "width" in kwargs: + return "rect" + elif "diameter" in kwargs: + return "disc" + elif "start_point" in kwargs and "end_point" in kwargs: + return "line" + elif "integration_points" in kwargs: + return "custom" + else: + raise ValueError("Could not infer element type from provided arguments.") + def add_annular_array(self, position, radius, diameters, focus_pos): assert isinstance(position, (list, tuple)), "'position' must be list or tuple" assert isinstance(radius, (int, float)), "'radius' must be an integer or float" @@ -170,27 +418,14 @@ def add_annular_element(self, position, radius, diameters, focus_pos): if self.dim != 3: raise ValueError(f"3D annular array cannot be added to an array with {self.dim}D elements.") - self.number_elements += 1 - - varphi_min = arcsin(diameters[0] / (2 * radius)) - varphi_max = arcsin(diameters[1] / (2 * radius)) - - area = 2 * pi * radius**2 * (1 - cos(varphi_max)) - 2 * pi * radius**2 * (1 - cos(varphi_min)) - - self.elements.append( - Element( - group_id=0, - type="annulus", - dim=2, - position=array(position), - radius_of_curvature=radius, - inner_diameter=diameters[0], - outer_diameter=diameters[1], - focus_position=array(focus_pos), - active=True, - measure=area, - ) - ) + kwargs = { + "position": position, + "radius_of_curvature": radius, + "inner_diameter": diameters[0], + "outer_diameter": diameters[1], + "focus_position": focus_pos, + } + self.add_element(**kwargs) def add_bowl_element(self, position, radius, diameter, focus_pos): assert isinstance(position, (list, tuple)), "'position' must be list or tuple" @@ -206,30 +441,13 @@ def add_bowl_element(self, position, radius, diameter, focus_pos): if self.dim != 3: raise ValueError(f"3D bowl element cannot be added to an array with {self.dim}D elements.") - self.number_elements += 1 - - varphi_max = arcsin(diameter / (2 * radius)) + kwargs = {"position": position, "radius_of_curvature": radius, "diameter": diameter, "focus_position": focus_pos} + self.add_element(**kwargs) - area = 2 * pi * radius**2 * (1 - cos(varphi_max)) - - self.elements.append( - Element( - group_id=0, - type="bowl", - dim=2, - position=array(position), - radius_of_curvature=radius, - diameter=diameter, - focus_position=array(focus_pos), - active=True, - measure=area, - ) - ) - - def add_custom_element(self, integration_points, measure, element_dim, label): + def add_custom_element(self, integration_points, measure, dim, label): assert isinstance(integration_points, (np.ndarray)), "'integration_points' must be a numpy array" assert isinstance(measure, (int, float)), "'measure' must be an integer or float" - assert isinstance(element_dim, (int)) and element_dim in [1, 2, 3], "'element_dim' must be an integer and either 1, 2 or 3" + assert isinstance(dim, (int)) and dim in [1, 2, 3], "'dim' must be an integer and either 1, 2 or 3" assert isinstance(label, (str)), "'label' must be a string" # check the dimensionality of the integration points @@ -246,13 +464,8 @@ def add_custom_element(self, integration_points, measure, element_dim, label): if self.dim != input_dim: raise ValueError(f"{input_dim}D custom element cannot be added to an array with {self.dim}D elements.") - self.number_elements += 1 - - self.elements.append( - Element( - group_id=0, type="custom", dim=element_dim, label=label, integration_points=integration_points, active=True, measure=measure - ) - ) + kwargs = {"integration_points": integration_points, "measure": measure, "dim": dim, "label": label} + self.add_element(**kwargs) def add_rect_element(self, position, Lx, Ly, theta): assert isinstance(position, (list, tuple)), "'position' must be a list or tuple" @@ -275,23 +488,8 @@ def add_rect_element(self, position, Lx, Ly, theta): if self.dim != coord_dim: raise ValueError(f"{coord_dim}D rectangular element cannot be added to an array with {self.dim}D elements.") - self.number_elements += 1 - - area = Lx * Ly - - self.elements.append( - Element( - group_id=0, - type="rect", - dim=2, - position=array(position), - length=Lx, - width=Ly, - orientation=array(theta) if coord_dim == 3 else theta, - active=True, - measure=area, - ) - ) + kwargs = {"position": position, "length": Lx, "width": Ly, "theta": theta} + self.add_element(**kwargs) def add_arc_element(self, position, radius, diameter, focus_pos): assert isinstance(position, (list, tuple, Vector)), "'position' must be list, tuple or Vector" @@ -307,25 +505,8 @@ def add_arc_element(self, position, radius, diameter, focus_pos): if self.dim != 2: raise ValueError(f"2D arc element cannot be added to an array with {self.dim}D elements.") - self.number_elements += 1 - - varphi_max = arcsin(diameter / (2 * radius)) - - length = 2 * radius * varphi_max - - self.elements.append( - Element( - group_id=0, - type="arc", - dim=1, - position=array(position), - radius_of_curvature=radius, - diameter=diameter, - focus_position=array(focus_pos), - active=True, - measure=length, - ) - ) + kwargs = {"position": position, "radius_of_curvature": radius, "diameter": diameter, "focus_position": focus_pos} + self.add_element(**kwargs) def add_disc_element(self, position, diameter, focus_pos=None): assert isinstance(position, (list, tuple)), "'position' must be a list or tuple" @@ -347,22 +528,8 @@ def add_disc_element(self, position, diameter, focus_pos=None): if self.dim != coord_dim: raise ValueError(f"{coord_dim}D disc element cannot be added to an array with {self.dim}D elements.") - self.number_elements += 1 - - area = pi * (diameter / 2) ** 2 - - self.elements.append( - Element( - group_id=0, - type="disc", - dim=2, - position=array(position), - diameter=diameter, - focus_position=array(focus_pos), - active=True, - measure=area, - ) - ) + kwargs = {"position": position, "diameter": diameter, "focus_position": focus_pos} + self.add_element(**kwargs) def remove_element(self, element_num): if element_num > self.number_elements: @@ -390,15 +557,8 @@ def add_line_element(self, start_point, end_point): if self.dim != input_dim: raise ValueError(f"{input_dim}D line element cannot be added to an array with {self.dim}D elements.") - self.number_elements += 1 - - line_length = linalg.norm(array(end_point) - array(start_point)) - - self.elements.append( - Element( - group_id=0, type="line", dim=1, start_point=array(start_point), end_point=array(end_point), active=True, measure=line_length - ) - ) + kwargs = {"start_point": start_point, "end_point": end_point} + self.add_element(**kwargs) def get_element_grid_weights(self, kgrid, element_num): return self.get_off_grid_points(kgrid, element_num, False) @@ -586,7 +746,7 @@ def get_off_grid_points(self, kgrid, element_num, mask_only): ) # keep points in the positive y domain - grid_weights = grid_weights[:, kgrid.Ny:] + grid_weights = grid_weights[:, kgrid.Ny :] else: # remove integration points which are outside grid @@ -877,3 +1037,7 @@ def off_grid_points( if display_wait_bar and (point_ind % wait_bar_update_freq == 0): tqdm.update(wait_bar_update_freq) return mask + + +if __name__ == "__main__": + pytest.main(["-v", __file__]) diff --git a/tests/matlab_test_data_collectors/python_testers/kWaveArray_test.py b/tests/matlab_test_data_collectors/python_testers/kWaveArray_test.py index 18d4eb3a..e9aa4142 100644 --- a/tests/matlab_test_data_collectors/python_testers/kWaveArray_test.py +++ b/tests/matlab_test_data_collectors/python_testers/kWaveArray_test.py @@ -6,10 +6,57 @@ from kwave.data import Vector from kwave.kgrid import kWaveGrid -from kwave.utils.kwave_array import kWaveArray +from kwave.utils.kwave_array import ( + ArcElement, + BowlElement, + CustomElement, + DiscElement, + LineElement, + RectElement, + kWaveArray, + AnnulusElement, +) from tests.matlab_test_data_collectors.python_testers.utils.check_equality import check_kwave_array_equality from tests.matlab_test_data_collectors.python_testers.utils.record_reader import TestRecordReader +type_to_class = { + "annulus": "AnnulusElement", + "bowl": "BowlElement", + "custom": "CustomElement", + "disc": "DiscElement", + "line": "LineElement", + "rect": "RectElement", + "arc": "ArcElement", +} + + +def compare_elements(expected_element, python_element): + # itterate through the properties of the element and compare them + # be sure to include properties of annulus element + for key, expected_value in expected_element.items(): + if key == "type": + assert type_to_class[expected_value] == python_element.__class__.__name__ + continue + actual_value = getattr(python_element, key) + if key == "integration_points": + actual_value = actual_value.tolist() + expected_value = expected_value.tolist() + if isinstance(expected_value, np.ndarray) and isinstance(actual_value, np.ndarray): + if expected_value.dtype == object or actual_value.dtype == object: + pass + elif isinstance(expected_value, str) and isinstance(actual_value, str): + actual_value == expected_value + else: + assert np.all(np.isclose(actual_value, expected_value)) + + +def fix_dim_bug(expected_value): + # TODO: position is 3D but dim is somehow 2D. Is this a bug in k-wave? + for element in expected_value["elements"]: + if element["type"] == "rect" or element["type"] == "disc": + element["dim"] = 3 + return expected_value + def test_kwave_array(): test_record_path = os.path.join(Path(__file__).parent, "collectedValues/kWaveArray.mat") @@ -17,6 +64,9 @@ def test_kwave_array(): kwave_array = kWaveArray() + # TODO: test elements individually + # create an annular element in kWaveArray and compare only the element created and it's properties to elements create in python + # Useful for checking if the defaults are set correctly check_kwave_array_equality(kwave_array, reader.expected_value_of("kwave_array")) reader.increment() @@ -39,44 +89,66 @@ def test_kwave_array(): kwave_array.add_annular_element([0, 0, 0], 5, [0.001, 0.03], [1, 5, -3]) check_kwave_array_equality(kwave_array, reader.expected_value_of("kwave_array")) + # compare with annular element created in python + expected_annulus_element = reader.expected_value_of("kwave_array")["elements"][-1] + annulus_element = AnnulusElement([0, 0, 0], 5, 0.001, 0.03, [1, 5, -3]) + compare_elements(expected_annulus_element, annulus_element) + reader.increment() kwave_array.add_bowl_element([0, 0, 0], 5, 4.3, [1, 5, -3]) check_kwave_array_equality(kwave_array, reader.expected_value_of("kwave_array")) - reader.increment() + expected_bowl_element = reader.expected_value_of("kwave_array")["elements"][-1] + bowl_element = BowlElement([0, 0, 0], 5, 4.3, [1, 5, -3]) + compare_elements(expected_bowl_element, bowl_element) + reader.increment() + integration_points = np.array([[1, 1, 1, 2, 2, 2, 3, 3, 3], [1, 2, 3, 1, 2, 3, 1, 2, 3], [0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=np.float32) kwave_array.add_custom_element( - integration_points=np.array( - [[1, 1, 1, 2, 2, 2, 3, 3, 3], [1, 2, 3, 1, 2, 3, 1, 2, 3], [0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=np.float32 - ), + integration_points=integration_points, measure=9, - element_dim=2, + dim=2, label="custom_3d", ) check_kwave_array_equality(kwave_array, reader.expected_value_of("kwave_array")) + expected_custom_element = reader.expected_value_of("kwave_array")["elements"][-1] + custom_element = CustomElement(integration_points, 9, 2, "custom_3d") + compare_elements(expected_custom_element, custom_element) + with pytest.raises(ValueError): kwave_array.add_custom_element( - integration_points=np.array([[1, 1, 1, 2, 2, 2, 3, 3, 3], [1, 2, 3, 1, 2, 3, 1, 2, 3]], dtype=np.float32), + integration_points=np.array([[1, 1, 1, 2, 2, 2, 3, 3, 3], [1, 2, 3, 1, 2, 3, 1, 2, 3]]), measure=9, - element_dim=2, + dim=2, label="custom_3d", ) reader.increment() kwave_array.add_rect_element([12, -8, 0.3], 3, 4, [2, 4, 5]) - check_kwave_array_equality(kwave_array, reader.expected_value_of("kwave_array")) + check_kwave_array_equality(kwave_array, fix_dim_bug(reader.expected_value_of("kwave_array"))) + rect_element = RectElement([12, -8, 0.3], 3, 4, [2, 4, 5]) + expected_rect_element = fix_dim_bug(reader.expected_value_of("kwave_array"))["elements"][-1] + compare_elements(expected_rect_element, rect_element) reader.increment() kwave_array.add_disc_element([0, 0.3, 12], 5, [1, 5, 8]) - check_kwave_array_equality(kwave_array, reader.expected_value_of("kwave_array")) + check_kwave_array_equality(kwave_array, fix_dim_bug(reader.expected_value_of("kwave_array"))) + + expected_disc_element = reader.expected_value_of("kwave_array")["elements"][-1] + expected_disc_element["dim"] = 3 + disc_element = DiscElement([0, 0.3, 12], 5, [1, 5, 8]) + compare_elements(expected_disc_element, disc_element) reader.increment() # test list input kwave_array = kWaveArray() kwave_array.add_arc_element([0, 0.3], 5, 4.3, [1, 5]) check_kwave_array_equality(kwave_array, reader.expected_value_of("kwave_array")) + arc_element = ArcElement([0, 0.3], 5, 4.3, [1, 5]) + expected_arc_element = reader.expected_value_of("kwave_array")["elements"] + compare_elements(expected_arc_element, arc_element) # test tuple input kwave_array = kWaveArray() kwave_array.add_arc_element((0, 0.3), 5, 4.3, (1, 5)) @@ -89,6 +161,9 @@ def test_kwave_array(): kwave_array.add_disc_element([0, 0.3], 5) check_kwave_array_equality(kwave_array, reader.expected_value_of("kwave_array")) + disc_element = DiscElement([0, 0.3], 5) + expected_disc_element = reader.expected_value_of("kwave_array")["elements"][-1] + compare_elements(expected_disc_element, disc_element) reader.increment() kwave_array.add_custom_element( @@ -120,6 +195,9 @@ def test_kwave_array(): kwave_array = kWaveArray() kwave_array.add_line_element([0, 3], [5, 2]) check_kwave_array_equality(kwave_array, reader.expected_value_of("kwave_array")) + line_element = LineElement([0, 3], [5, 2]) + expected_line_element = reader.expected_value_of("kwave_array")["elements"] + compare_elements(expected_line_element, line_element) reader.increment() kwave_array = kWaveArray() @@ -178,3 +256,7 @@ def test_kwave_array(): assert kwave_array.dim == 2 assert np.allclose(kwave_array.get_array_grid_weights(kgrid).shape, reader.expected_value_of("grid_weights").shape) assert np.allclose(kwave_array.get_array_grid_weights(kgrid), np.squeeze(reader.expected_value_of("grid_weights"))) + + +if __name__ == "__main__": + pytest.main(["-v", __file__])