From 5ba8cf354687286615211c9da9d5e7f42420eeaf Mon Sep 17 00:00:00 2001 From: Igor Tatarnikov <61896994+IgorTatarnikov@users.noreply.github.com> Date: Thu, 1 Aug 2024 11:25:28 +0100 Subject: [PATCH] Call BigStitcher by command line and read the results (#7) * Add macro that calls big stitcher * Add function to call BigStitcher via the command line * Add functions to read the output of BigStitcher * Add stitch button to napari widget * Added tests for big_stitcher_bridge * Added tests for file_utils * Fixed tests to account for new defaults in stitch function * Added tests for new functions in image_mosaic * Added tests for new ImageMosaic and StitchingWidget functions * Updated docstring from big_stitcher_bridge * Updated docs for file_utils * Added docstrings for ImageMosaic and StitchingWidget * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * WIP apply suggestions from code review * More code review changes * Move test constants out of test_image_mosaic to conftest * Moved constants out of test_file_utils * Moved constants out of test_big_stitcher_bridge * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Make sure all ImageMosaic objects are explicitly cleaned up * Update error for wrong imagej path * WIP adding docstrings to test_stitching_widgets * Added docstrings to test_stitching_widget * Added comments for conftest.py * Added docstrings for test_big_stitcher_bridge * Added tests for test_image_mosaic * Add docstring to test_file_utils --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- MANIFEST.in | 1 + brainglobe_stitch/big_stitcher_bridge.py | 67 +++++ brainglobe_stitch/bigstitcher_macro.ijm | 57 ++++ brainglobe_stitch/file_utils.py | 153 ++++++++++- brainglobe_stitch/image_mosaic.py | 93 ++++++- brainglobe_stitch/stitching_widget.py | 158 ++++++++++- brainglobe_stitch/tile.py | 4 +- tests/test_unit/conftest.py | 140 ++++++++++ tests/test_unit/test_big_stitcher_bridge.py | 99 +++++++ tests/test_unit/test_file_utils.py | 227 ++++++++++++++++ tests/test_unit/test_image_mosaic.py | 142 ++++++---- tests/test_unit/test_stitching_widget.py | 275 +++++++++++++++++--- tests/test_unit/test_tile.py | 2 +- 13 files changed, 1301 insertions(+), 117 deletions(-) create mode 100644 brainglobe_stitch/big_stitcher_bridge.py create mode 100644 brainglobe_stitch/bigstitcher_macro.ijm create mode 100644 tests/test_unit/test_big_stitcher_bridge.py create mode 100644 tests/test_unit/test_file_utils.py diff --git a/MANIFEST.in b/MANIFEST.in index 9175ede..0cf896e 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,6 +1,7 @@ include LICENSE include README.md include brainglobe_stitch/napari.yaml +include brainglobe_stitch/bigstitcher_macro.ijm exclude .pre-commit-config.yaml recursive-include brainglobe_stitch *.py diff --git a/brainglobe_stitch/big_stitcher_bridge.py b/brainglobe_stitch/big_stitcher_bridge.py new file mode 100644 index 0000000..7651e85 --- /dev/null +++ b/brainglobe_stitch/big_stitcher_bridge.py @@ -0,0 +1,67 @@ +import subprocess +from pathlib import Path +from platform import system + + +def run_big_stitcher( + imagej_path: Path, + xml_path: Path, + tile_config_path: Path, + all_channels: bool = False, + selected_channel: int = 488, + downsample_x: int = 4, + downsample_y: int = 4, + downsample_z: int = 1, +) -> subprocess.CompletedProcess: + """ + Run the BigStitcher ImageJ macro. Output is captured and returned as part + of the subprocess.CompletedProcess. + + Parameters + ---------- + imagej_path : Path + The path to the ImageJ executable. + xml_path : Path + The path to the BigDataViewer XML file. + tile_config_path : Path + The path to the BigStitcher tile configuration file. + all_channels : bool, optional + Whether to stitch based on all channels (default False). + selected_channel : int, optional + The channel on which to base the stitching (default 488). + downsample_x : int, optional + The downsample factor in the x-dimension for the stitching (default 4). + downsample_y : int, optional + The downsample factor in the y-dimension for the stitching (default 4). + downsample_z : int, optional + The downsample factor in the z-dimension for the stitching (default 1). + + Returns + ------- + subprocess.CompletedProcess + The result of the subprocess run. + + Raises + ------ + subprocess.CalledProcessError + If the subprocess returns a non-zero exit status. + """ + stitch_macro_path = ( + Path(__file__).resolve().parent / "bigstitcher_macro.ijm" + ) + + if system().startswith("Darwin"): + imagej_path = imagej_path / "Contents/MacOS/ImageJ-macosx" + + command = ( + f"{imagej_path} --ij2 " + f"--headless -macro {stitch_macro_path} " + f'"{xml_path} {tile_config_path} {int(all_channels)} ' + f'{selected_channel} {downsample_x} {downsample_y} {downsample_z}"' + ) + + result = subprocess.run( + command, capture_output=True, text=True, check=True, shell=True + ) + + return result diff --git a/brainglobe_stitch/bigstitcher_macro.ijm b/brainglobe_stitch/bigstitcher_macro.ijm new file mode 100644 index 0000000..124f902 --- /dev/null +++ b/brainglobe_stitch/bigstitcher_macro.ijm @@ -0,0 +1,57 @@ +args = getArgument(); +args = split(args, " "); +xmlPath = args[0]; +tilePath = args[1]; +allChannels = args[2]; +selectedChannel = args[3]; +downSampleX = args[4]; +downSampleY = args[5]; +downSampleZ = args[6]; + +print("Stitching " + xmlPath); +print("Loading TileConfiguration from " + tilePath) +run( + "Load TileConfiguration from File...", + "browse=" + xmlPath + + " select=" + xmlPath + + " tileconfiguration=" + tilePath + + " use_pixel_units keep_metadata_rotation" +); + +print("Calculating pairwise shifts"); +if (allChannels == 1) { + run( + "Calculate pairwise shifts ...", + "select=" + xmlPath + + " process_angle=[All angles] process_channel=[All channels] process_illumination=[All illuminations]" + + " process_tile=[All tiles] process_timepoint=[All Timepoints]" + + " method=[Phase Correlation] channels=[Average Channels]" + + " downsample_in_x=" + downSampleX + + " downsample_in_y=" + downSampleY + + " downsample_in_z=" + downSampleZ + ); +} else { + run( + "Calculate pairwise shifts ...", + "select=" + xmlPath + + " process_angle=[All angles] process_channel=[All channels] process_illumination=[All illuminations]" + + " process_tile=[All tiles] process_timepoint=[All Timepoints]" + + " method=[Phase Correlation] channels=[use Channel " + selectedChannel + " nm]" + + " downsample_in_x=" + downSampleX + + " downsample_in_y=" + downSampleY + + " downsample_in_z=" + downSampleZ + ); +} + +print("Optimizing globally and applying shifts"); +run( + "Optimize globally and apply shifts ...", + "select=" + xmlPath + + " process_angle=[All angles] process_channel=[All channels] process_illumination=[All illuminations]" + + " process_tile=[All tiles] process_timepoint=[All Timepoints] relative=2.500 absolute=3.500" + + " global_optimization_strategy=[Two-Round using Metadata to align unconnected Tiles and iterative dropping of bad links]" + + " fix_group_0-0," +); + +print("Done"); +eval("script", "System.exit(0);"); diff --git a/brainglobe_stitch/file_utils.py b/brainglobe_stitch/file_utils.py index 8aa95c0..1f7d33f 100644 --- a/brainglobe_stitch/file_utils.py +++ b/brainglobe_stitch/file_utils.py @@ -34,16 +34,16 @@ def create_pyramid_bdv_h5( Parameters ---------- - input_file: Path + input_file : Path The path to the input HDF5 file. - yield_progress: bool, optional + yield_progress : bool, optional Whether to yield progress. If True, the function will yield the progress as a percentage. - resolutions_array: npt.NDArray, optional + resolutions_array : npt.NDArray, optional The downsampling factors to use for each resolution level. This is a 2D array where each row represents a resolution level and the columns represent the downsampling factors for x, y, and z. - subdivisions_array: npt.NDArray, optional + subdivisions_array : npt.NDArray, optional The size of the blocks at each resolution level. This is a 2D array where each row represents a resolution level and the columns represent the size of the blocks for x, y, and z. @@ -113,7 +113,7 @@ def parse_mesospim_metadata( Parameters ---------- - meta_file_name: Path + meta_file_name : Path The path to the h5_meta.txt file. Returns @@ -163,7 +163,7 @@ def check_mesospim_directory( Parameters ---------- - mesospim_directory: Path + mesospim_directory : Path The path to the mesoSPIM directory. Returns @@ -203,9 +203,9 @@ def get_slice_attributes( Parameters ---------- - xml_path: Path + xml_path : Path The path to the XML file. - tile_names: List[str] + tile_names : List[str] The names of the tiles. Returns @@ -226,3 +226,140 @@ def get_slice_attributes( slice_attributes[name] = sub_dict return slice_attributes + + +def get_big_stitcher_transforms(xml_path: Path) -> npt.NDArray: + """ + Get the translations for each tile from a Big Data Viewer XML file. + The translations are calculated by BigStitcher. + + Parameters + ---------- + xml_path : Path + The path to the Big Data Viewer XML file. + + Returns + ------- + npt.NDArray + A numpy array of shape (num_tiles, num_dim) with the translations. + Each row corresponds to a tile and each column to a dimension. + """ + tree = ET.parse(xml_path) + root = tree.getroot() + + stitch_transforms = safe_find_all( + root, ".//ViewTransform/[Name='Stitching Transform']/affine" + ) + + # Stitching Transforms are there if aligning to grid is done manually + # Translation from Tile Configurations are there if aligned automatically + grid_transforms = safe_find_all( + root, + ".//ViewTransform/[Name='Translation from Tile Configuration']/affine", + ) + if len(grid_transforms) == 0: + grid_transforms = safe_find_all( + root, + ".//ViewTransform/[Name='Translation to Regular Grid']/affine", + ) + + z_scale_str = safe_find( + root, ".//ViewTransform/[Name='calibration']/affine" + ) + + if not z_scale_str.text: + raise ValueError("No z scale found in XML") + + z_scale = float(z_scale_str.text.split()[-2]) + + deltas = np.ones((len(stitch_transforms), 3)) + grids = np.ones((len(grid_transforms), 3)) + for i in range(len(stitch_transforms)): + delta_nums_text = stitch_transforms[i].text + grid_nums_text = grid_transforms[i].text + + if not delta_nums_text or not grid_nums_text: + raise ValueError("No translation values found in XML") + + delta_nums = delta_nums_text.split() + grid_nums = grid_nums_text.split() + + # Extract the translation values from the transform. + # Swap the order of the axis (x,y,z) to (z,y,x). + # The input values are a flattened 4x4 matrix where + # the translation values in the last column. + deltas[i] = np.array(delta_nums[11:2:-4]) + grids[i] = np.array(grid_nums[11:2:-4]) + + # Divide the z translation by the z scale + deltas[:, 0] /= z_scale + grids[:, 0] /= z_scale + + # Round the translations to the nearest integer + grids = grids.round().astype(np.int32) + deltas = deltas.round().astype(np.int32) + + # Normalise the grid transforms by subtracting the minimum value + norm_grids = grids - grids.min(axis=0) + # Calculate the maximum delta (from BigStitcher) for each dimension + max_delta = np.absolute(deltas).max(axis=0) + + # Calculate the start and end coordinates for each tile such that the + # first tile is at 0,0,0 and provide enough padding to account for the + # transforms from BigStitcher + translations = norm_grids + deltas + max_delta + + return translations + + +def safe_find_all(root: ET.Element, query: str) -> List[ET.Element]: + """ + Find all elements matching a query in an ElementTree root. If no + elements are found, return an empty list. + + Parameters + ---------- + root : ET.Element + The root of the ElementTree. + query : str + The query to search for. + + Returns + ------- + List[ET.Element] + A list of elements matching the query. + """ + elements = root.findall(query) + if elements is None: + return [] + + return elements + + +def safe_find(root: ET.Element, query: str) -> ET.Element: + """ + Find the first element matching a query in an ElementTree root. + Raise a ValueError if no element found. + + Parameters + ---------- + root : ET.Element + The root of the ElementTree. + query : str + The query to search for. + + Returns + ------- + ET.Element + The element matching the query or None. + + Raises + ------ + ValueError + If no element is found. + """ + element = root.find(query) + if element is None or element.text is None: + raise ValueError(f"No element found for query {query}") + + return element diff --git a/brainglobe_stitch/image_mosaic.py b/brainglobe_stitch/image_mosaic.py index a3d470f..3ac6193 100644 --- a/brainglobe_stitch/image_mosaic.py +++ b/brainglobe_stitch/image_mosaic.py @@ -1,4 +1,5 @@ from pathlib import Path +from time import sleep from typing import Dict, List, Optional, Tuple import dask.array as da @@ -7,9 +8,11 @@ import numpy.typing as npt from rich.progress import Progress +from brainglobe_stitch.big_stitcher_bridge import run_big_stitcher from brainglobe_stitch.file_utils import ( check_mesospim_directory, create_pyramid_bdv_h5, + get_big_stitcher_transforms, get_slice_attributes, parse_mesospim_metadata, ) @@ -24,15 +27,15 @@ class ImageMosaic: ---------- directory : Path The directory containing the image data. - xml_path : Path | None + xml_path : Optional[Path] The path to the Big Data Viewer XML file. - meta_path : Path | None + meta_path : Optional[Path] The path to the mesoSPIM metadata file. - h5_path : Path | None + h5_path : Optional[Path] The path to the Big Data Viewer h5 file containing the raw data. - tile_config_path : Path | None + tile_config_path : Optional[Path] The path to the BigStitcher tile configuration file. - h5_file : h5py.File | None + h5_file : Optional[h5py.File] An open h5py file object for the raw data. channel_names : List[str] The names of the channels in the image as strings. @@ -80,7 +83,7 @@ def data_for_napari( Parameters ---------- - resolution_level: int + resolution_level : int The resolution level to get the data for. Returns @@ -220,7 +223,7 @@ def write_big_stitcher_tile_config(self, meta_file_name: Path) -> None: Parameters ---------- - meta_file_name: Path + meta_file_name : Path The path to the mesoSPIM metadata file. """ # Remove .h5_meta.txt from the file name @@ -261,3 +264,79 @@ def write_big_stitcher_tile_config(self, meta_file_name: Path) -> None: ) return + + def stitch( + self, + fiji_path: Path, + resolution_level: int, + selected_channel: str, + ) -> None: + """ + Stitch the tiles in the image using BigStitcher. + + Parameters + ---------- + fiji_path : Path + The path to the Fiji application. + resolution_level : int + The resolution level to stitch the tiles at. + selected_channel : str + The name of the channel to stitch. + """ + + # If selected_channel is an empty string then stitch based on + # all channels + all_channels = len(selected_channel) == 0 + channel_int = -1 + + # Extract the wavelength from the channel name + if not all_channels: + try: + channel_int = int(selected_channel.split()[0]) + except ValueError: + raise ValueError("Invalid channel name.") + + # Extract the downsample factors for the selected resolution level + downsample_z, downsample_y, downsample_x = self.tiles[ + 0 + ].resolution_pyramid[resolution_level] + + assert self.xml_path is not None + assert self.tile_config_path is not None + + result = run_big_stitcher( + fiji_path, + self.xml_path, + self.tile_config_path, + all_channels, + channel_int, + downsample_x=downsample_x, + downsample_y=downsample_y, + downsample_z=downsample_z, + ) + + big_stitcher_output_path = self.directory / "big_stitcher_output.txt" + + with open(big_stitcher_output_path, "w") as f: + f.write(result.stdout) + f.write(result.stderr) + + # Print the output of BigStitcher to the command line + print(result.stdout) + + # Wait for the BigStitcher to write XML file + # Need to find a better way to do this + sleep(1) + + self.read_big_stitcher_transforms() + + def read_big_stitcher_transforms(self) -> None: + """ + Read the BigStitcher transforms from the XML file and update the tile + positions accordingly. + """ + assert self.xml_path is not None + stitched_translations = get_big_stitcher_transforms(self.xml_path) + for tile in self.tiles: + stitched_position = stitched_translations[tile.id] + tile.position = stitched_position diff --git a/brainglobe_stitch/stitching_widget.py b/brainglobe_stitch/stitching_widget.py index 3b86ce2..53f04e8 100644 --- a/brainglobe_stitch/stitching_widget.py +++ b/brainglobe_stitch/stitching_widget.py @@ -1,15 +1,17 @@ from pathlib import Path -from typing import List, Optional, Tuple +from typing import Dict, List, Optional, Tuple import dask.array as da import h5py import napari +import numpy as np import numpy.typing as npt from brainglobe_utils.qtpy.logo import header_widget from napari import Viewer from napari.qt.threading import create_worker -from napari.utils.notifications import show_warning +from napari.utils.notifications import show_info, show_warning from qtpy.QtWidgets import ( + QComboBox, QFileDialog, QHBoxLayout, QLabel, @@ -28,7 +30,7 @@ def add_tiles_from_mosaic( - napari_data: List[Tuple[da.Array, npt.NDArray]], tile_names: List[str] + napari_data: List[Tuple[da.Array, npt.NDArray]], image_mosaic: ImageMosaic ): """ Add tiles to the napari viewer from the ImageMosaic. @@ -37,17 +39,35 @@ def add_tiles_from_mosaic( ------------ napari_data : List[Tuple[da.Array, npt.NDArray]] The data and position for each tile in the mosaic. - tile_names : List[str] - The list of tile names. + image_mosaic : ImageMosaic + The ImageMosaic object containing the data for the tiles. """ - - for data, tile_name in zip(napari_data, tile_names): + middle_slice = napari_data[0][0].shape[0] // 2 + thresholds: Dict[str, List[float]] = {} + + for data, tile in zip(napari_data, image_mosaic.tiles): + tile_data, _ = data + curr_threshold = np.percentile( + tile_data[middle_slice].ravel(), 99 + ).compute()[0] + threshold_list = thresholds.get(tile.channel_name, []) + threshold_list.append(curr_threshold) + thresholds[tile.channel_name] = threshold_list + + final_thresholds: Dict[str, float] = dict( + (channel, np.max(thresholds.get(channel))) for channel in thresholds + ) + + for data, tile_name, tile in zip( + napari_data, image_mosaic.tile_names, image_mosaic.tiles + ): + channel_name = tile.channel_name tile_data, tile_position = data tile_layer = napari.layers.Image( tile_data.compute(), name=tile_name, blending="translucent", - contrast_limits=[0, 4000], + contrast_limits=[0, final_thresholds[channel_name]], multiscale=False, ) tile_layer.translate = tile_position @@ -70,6 +90,8 @@ class StitchingWidget(QWidget): The progress bar for the widget, reused for multiple function. image_mosaic : Optional[ImageMosaic] The ImageMosaic object representing the data that will be stitched. + imagej_path : Optional[Path] + The path to the ImageJ executable. tile_layers : List[napari.layers.Image] The list of napari layers containing the tiles. resolution_to_display : int @@ -85,11 +107,21 @@ class StitchingWidget(QWidget): mesospim_directory_text_field : QLineEdit The text field for the mesoSPIM directory. open_file_dialog : QPushButton - The button for opening the file dialog. + The button for opening the file dialog for the mesoSPIM directory. create_pyramid_button : QPushButton The button for creating the resolution pyramid. add_tiles_button : QPushButton The button for adding the tiles to the viewer. + select_imagej_path : QWidget + The widget for selecting the ImageJ path. + imagej_path_text_field : QLineEdit + The text field for the ImageJ path. + open_file_dialog_imagej : QPushButton + The button for opening the file dialog for the ImageJ path.# + fuse_channel_dropdown : QComboBox + The dropdown for selecting the channel to fuse. + stitch_button : QPushButton + The button for stitching the tiles. """ def __init__(self, napari_viewer: Viewer): @@ -97,6 +129,7 @@ def __init__(self, napari_viewer: Viewer): self._viewer = napari_viewer self.progress_bar = QProgressBar(self) self.image_mosaic: Optional[ImageMosaic] = None + self.imagej_path: Optional[Path] = None self.tile_layers: List[napari.layers.Image] = [] self.resolution_to_display: int = 3 @@ -150,6 +183,37 @@ def __init__(self, napari_viewer: Viewer): self.add_tiles_button.setEnabled(False) self.layout().addWidget(self.add_tiles_button) + self.select_imagej_path = QWidget() + self.select_imagej_path.setLayout(QHBoxLayout()) + + self.imagej_path_text_field = QLineEdit() + self.imagej_path_text_field.setText(str(self.default_directory)) + self.imagej_path_text_field.editingFinished.connect( + self._on_imagej_path_text_edited + ) + self.select_imagej_path.layout().addWidget(self.imagej_path_text_field) + + self.open_file_dialog_imagej = QPushButton("Browse") + self.open_file_dialog_imagej.clicked.connect( + self._on_open_file_dialog_imagej_clicked + ) + self.select_imagej_path.layout().addWidget( + self.open_file_dialog_imagej + ) + + self.layout().addWidget(QLabel("Path to ImageJ executable:")) + self.layout().addWidget(self.select_imagej_path) + + self.fuse_channel_dropdown = QComboBox(parent=self) + self.layout().addWidget(self.fuse_channel_dropdown) + + self.stitch_button = QPushButton("Stitch") + self.stitch_button.clicked.connect(self._on_stitch_button_clicked) + self.stitch_button.setEnabled(False) + self.layout().addWidget(self.stitch_button) + + self.layout().addWidget(self.progress_bar) + def _on_open_file_dialog_clicked(self): """ Open a file dialog to select the mesoSPIM directory. @@ -197,12 +261,15 @@ def _on_add_tiles_button_clicked(self): """ self.image_mosaic = ImageMosaic(self.working_directory) + self.fuse_channel_dropdown.clear() + self.fuse_channel_dropdown.addItems(self.image_mosaic.channel_names) + napari_data = self.image_mosaic.data_for_napari( self.resolution_to_display ) worker = create_worker( - add_tiles_from_mosaic, napari_data, self.image_mosaic.tile_names + add_tiles_from_mosaic, napari_data, self.image_mosaic ) worker.yielded.connect(self._set_tile_layers) worker.start() @@ -239,3 +306,74 @@ def check_and_load_mesospim_directory(self): self.add_tiles_button.setEnabled(True) except FileNotFoundError: show_warning("mesoSPIM directory not valid") + + def _on_open_file_dialog_imagej_clicked(self): + """ + Open a file dialog to select the FIJI path. + """ + self.imagej_path = Path( + QFileDialog.getOpenFileName( + self, "Select FIJI Path", str(self.default_directory) + )[0] + ) + self.imagej_path_text_field.setText(str(self.imagej_path)) + self.check_imagej_path() + + def _on_imagej_path_text_edited(self): + """ + Update the FIJI path when the text field is edited. + """ + self.imagej_path = Path(self.imagej_path_text_field.text()) + self.check_imagej_path() + + def _on_stitch_button_clicked(self): + """ + Stitch the tiles in the viewer using BigStitcher. + """ + if self.image_mosaic is None: + show_warning("Open a mesoSPIM directory prior to stitching") + return + + self.image_mosaic.stitch( + self.imagej_path, + resolution_level=2, + selected_channel=self.fuse_channel_dropdown.currentText(), + ) + + show_info("Stitching complete") + + napari_data = self.image_mosaic.data_for_napari( + self.resolution_to_display + ) + + self.update_tiles_from_mosaic(napari_data) + + def check_imagej_path(self): + """ + Check if the selected ImageJ path is valid. If valid, enable the + stitch button. Otherwise, show a warning. + """ + if self.imagej_path.is_file(): + self.stitch_button.setEnabled(True) + else: + show_warning( + "ImageJ path not valid. " + "Please select a valid path to the imageJ executable." + ) + + def update_tiles_from_mosaic( + self, napari_data: List[Tuple[da.Array, npt.NDArray]] + ): + """ + Update the data stored in the napari viewer for each tile based on + the ImageMosaic. + + Parameters + ---------- + napari_data : List[Tuple[da.Array, npt.NDArray]] + The data and position for each tile in the mosaic. + """ + for data, tile_layer in zip(napari_data, self.tile_layers): + tile_data, tile_position = data + tile_layer.data = tile_data.compute() + tile_layer.translate = tile_position diff --git a/brainglobe_stitch/tile.py b/brainglobe_stitch/tile.py index 73daa34..3526c14 100644 --- a/brainglobe_stitch/tile.py +++ b/brainglobe_stitch/tile.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional +from typing import Dict, List import dask.array as da import numpy as np @@ -50,7 +50,7 @@ def __init__( self.neighbours: List[int] = [] self.data_pyramid: List[da.Array] = [] self.resolution_pyramid: npt.NDArray = np.array([]) - self.channel_name: Optional[str] = None + self.channel_name: str = "" self.channel_id: int = int(attributes["channel"]) self.tile_id: int = int(attributes["tile"]) self.illumination_id: int = int(attributes["illumination"]) diff --git a/tests/test_unit/conftest.py b/tests/test_unit/conftest.py index cf87a31..fdce943 100644 --- a/tests/test_unit/conftest.py +++ b/tests/test_unit/conftest.py @@ -1,5 +1,6 @@ import shutil from pathlib import Path +from platform import system import pooch import pytest @@ -10,6 +11,15 @@ @pytest.fixture(scope="session", autouse=True) def download_test_data(): + """ + Downloads the test data and extracts it to a temporary directory. + This fixture is session-scoped and automatically run once. + + Yields + ------ + Path + The path to the temporary directory. + """ TEMP_DIR.mkdir(exist_ok=True) pooch.retrieve( TEST_DATA_URL, @@ -24,11 +34,30 @@ def download_test_data(): @pytest.fixture(scope="session") def test_data_directory(): + """ + Returns the path to the clean test data directory. + + Yields + ------ + Path + The path to the clean test data directory. + """ yield TEMP_DIR @pytest.fixture(scope="module") def naive_bdv_directory(): + """ + Creates a temporary directory and copies the test data to it. This allows + tests to modify the directory without affecting the original test data. + + The temporary directory is cleaned up after the tests are run. + + Yields + ------ + Path + The path to the temporary directory. + """ test_dir = Path.home() / "test_directory" shutil.copytree( @@ -49,6 +78,19 @@ def naive_bdv_directory(): @pytest.fixture def bdv_directory_function_level(): + """ + Creates a temporary directory and copies the test data to it. This allows + tests to modify the directory without affecting the original test data. + + This fixture is function-scoped. + + The temporary directory is cleaned up after the tests are run. + + Yields + ------ + Path + The path to the temporary directory. + """ test_dir = Path.home() / "quick_test_directory" shutil.copytree( @@ -61,3 +103,101 @@ def bdv_directory_function_level(): yield test_dir shutil.rmtree(test_dir) + + +@pytest.fixture(scope="session") +def imagej_path(): + """ + Returns the path to a mock ImageJ executable based on the operating + system. This is used to mimic the behavior of the QFileDialog in macOS, + it returns a path to the "Fiji.app" directory instead of the executable. + + Returns + ------- + Path + The path to the mock ImageJ executable. + """ + if system() == "Windows": + return Path.home() / "Fiji.app/ImageJ-win64.exe" + elif system() == "Darwin": + return Path.home() / "Fiji.app" + else: + return Path.home() / "Fiji.app/ImageJ-linux64" + + +@pytest.fixture(scope="module") +def test_constants(imagej_path): + """ + Provides a dictionary of constants that's used in the tests. + Contains metadata about the test data and the expected results. + + The tiles lie in one z-plane and are arranged in a 2x2 grid. + There are 2 channels. + Each tile is 128x128x107 pixels (x, y, z). + The tiles overlap by 10% in x and y (13 pixels). + The tiles are arranged in the following pattern: + channel 0 | channel 1 + 00 10 | 04 05 + 01 11 | 14 15 + + EXPECTED_TILE_CONFIG is based on test_data_bdv.xml + The tile positions are in pixels in x, y, z order + + EXPECTED_TILE_POSITIONS are based on the stitch transforms found + in test_data_bdv.xml + The tile positions are in pixels in z, y, x order + + Parameters + ---------- + imagej_path : Path + The path to the mock ImageJ executable. + + Returns + ------- + Dict + A dictionary containing the constants. + """ + constants_dict = { + "NUM_TILES": 8, + "NUM_CHANNELS": 2, + "NUM_RESOLUTIONS": 5, + "TILE_SIZE": (107, 128, 128), + "EXPECTED_TILE_CONFIG": [ + "dim=3", + "00;;(0,0,0)", + "01;;(0,115,0)", + "04;;(0,0,0)", + "05;;(0,115,0)", + "10;;(115,0,0)", + "11;;(115,115,0)", + "14;;(115,0,0)", + "15;;(115,115,0)", + ], + "EXPECTED_TILE_POSITIONS": [ + [3, 4, 2], + [2, 120, 0], + [3, 4, 2], + [2, 120, 0], + [6, 7, 118], + [5, 123, 116], + [6, 7, 118], + [5, 123, 116], + ], + "CHANNELS": ["561 nm", "647 nm"], + "PIXEL_SIZE_XY": 4.08, + "PIXEL_SIZE_Z": 5.0, + "MOCK_IMAGEJ_PATH": imagej_path, + # The file dialogue on macOS has a different behaviour + # The selected file path is to the "Fiji.app" directory + # The ImageJ executable is in "Fiji.app/Contents/MacOS/ImageJ-macosx" + "MOCK_IMAGEJ_EXEC_PATH": ( + imagej_path / "Contents/MacOS/ImageJ-macosx" + if system() == "Darwin" + else imagej_path + ), + "MOCK_XML_PATH": Path.home() / "stitching/Brain2/bdv.xml", + "MOCK_TILE_CONFIG_PATH": Path.home() + / "stitching/Brain2/bdv_tile_config.txt", + } + + return constants_dict diff --git a/tests/test_unit/test_big_stitcher_bridge.py b/tests/test_unit/test_big_stitcher_bridge.py new file mode 100644 index 0000000..d9e36bf --- /dev/null +++ b/tests/test_unit/test_big_stitcher_bridge.py @@ -0,0 +1,99 @@ +from importlib.resources import files + +import pytest + +from brainglobe_stitch.big_stitcher_bridge import run_big_stitcher + + +def test_run_big_stitcher_defaults(mocker, test_constants): + """ + Test the run_big_stitcher function with default parameters. Mocks + the subprocess.run function to check if the correct command is + passed and to prevent the actual command from running. + """ + mock_subprocess_run = mocker.patch( + "brainglobe_stitch.big_stitcher_bridge.subprocess.run" + ) + + imagej_path = test_constants["MOCK_IMAGEJ_PATH"] + xml_path = test_constants["MOCK_XML_PATH"] + tile_config_path = test_constants["MOCK_TILE_CONFIG_PATH"] + + run_big_stitcher(imagej_path, xml_path, tile_config_path) + + # Expected path to the ImageJ macro + # Should be in the root of the package + macro_path = files("brainglobe_stitch") / "bigstitcher_macro.ijm" + + expected_imagej_path = test_constants["MOCK_IMAGEJ_EXEC_PATH"] + + command = ( + f"{expected_imagej_path} --ij2" + f" --headless -macro {macro_path} " + f'"{xml_path} {tile_config_path} 0 488 4 4 1"' + ) + + mock_subprocess_run.assert_called_with( + command, capture_output=True, text=True, check=True, shell=True + ) + + +@pytest.mark.parametrize( + "all_channels, selected_channel, downsample_x, downsample_y, downsample_z", + [ + (False, 488, 4, 4, 4), + (True, 488, 4, 4, 4), + (False, 488, 4, 8, 16), + (True, 576, 4, 8, 16), + ], +) +def test_run_big_stitcher( + mocker, + all_channels, + selected_channel, + downsample_x, + downsample_y, + downsample_z, + test_constants, +): + """ + Test the run_big_stitcher function with custom parameters. Mocks + the subprocess.run function to check if the correct command is + passed and to prevent the actual command from running. + """ + mock_subprocess_run = mocker.patch( + "brainglobe_stitch.big_stitcher_bridge.subprocess.run" + ) + + imagej_path = test_constants["MOCK_IMAGEJ_PATH"] + xml_path = test_constants["MOCK_XML_PATH"] + tile_config_path = test_constants["MOCK_TILE_CONFIG_PATH"] + + run_big_stitcher( + imagej_path, + xml_path, + tile_config_path, + all_channels=all_channels, + selected_channel=selected_channel, + downsample_x=downsample_x, + downsample_y=downsample_y, + downsample_z=downsample_z, + ) + + # Expected path to the ImageJ macro + # Should be in the root of the package + macro_path = files("brainglobe_stitch").joinpath("bigstitcher_macro.ijm") + + expected_imagej_path = test_constants["MOCK_IMAGEJ_EXEC_PATH"] + + command = ( + f"{expected_imagej_path} --ij2" + f" --headless -macro {macro_path} " + f'"{xml_path} {tile_config_path} {int(all_channels)} ' + f"{selected_channel} " + f'{downsample_x} {downsample_y} {downsample_z}"' + ) + + mock_subprocess_run.assert_called_with( + command, capture_output=True, text=True, check=True, shell=True + ) diff --git a/tests/test_unit/test_file_utils.py b/tests/test_unit/test_file_utils.py new file mode 100644 index 0000000..7a288ee --- /dev/null +++ b/tests/test_unit/test_file_utils.py @@ -0,0 +1,227 @@ +import shutil +from pathlib import Path + +import h5py +import numpy as np +import pytest + +from brainglobe_stitch.file_utils import ( + check_mesospim_directory, + create_pyramid_bdv_h5, + get_big_stitcher_transforms, + get_slice_attributes, + parse_mesospim_metadata, +) + + +@pytest.fixture +def invalid_bdv_directory(): + """ + Fixture for creating an invalid directory for testing. + The directory is created but empty. The directory is deleted after each + test. + + Yields + ------ + Path + The invalid directory for testing. + """ + bad_dir = Path("./bad_directory") + bad_dir.mkdir() + + yield bad_dir + + shutil.rmtree(bad_dir) + + +def test_create_pyramid_bdv_h5( + naive_bdv_directory, test_data_directory, test_constants +): + """ + Check the create_pyramid_bdv_h5 function. The function should create a + resolution pyramid of depth 5 for each tile in the h5 file. The function + modifies the h5 file in place. The results are checked by comparing the + modified h5 file to the expected h5 file, which is stored in the + test_data_directory. + """ + # Sanity check to ensure that the test h5 file doesn't contain any + # resolutions or subdivisions, and no resolution pyramid. + h5_path = naive_bdv_directory / "test_data_bdv.h5" + with h5py.File(h5_path, "r") as f: + num_tiles = len(f["t00000"].keys()) + tile_names = f["t00000"].keys() + + for tile_name in tile_names: + assert f[f"{tile_name}/resolutions"].shape[0] == 1 + assert f[f"{tile_name}/subdivisions"].shape[0] == 1 + assert len(f[f"t00000/{tile_name}"].keys()) == 1 + + # Run the function and check that the correct value is yielded for each + # iteration (percent complete) + num_done = 1 + for progress in create_pyramid_bdv_h5(h5_path, yield_progress=True): + assert progress == int(100 * num_done / num_tiles) + num_done += 1 + + with ( + h5py.File(h5_path, "r") as f_out, + h5py.File(test_data_directory / "test_data_bdv.h5", "r") as f_in, + ): + # Check that the number of groups/datasets in the parent is unchanged + assert len(f_out.keys()) == len(f_in.keys()) + assert len(f_out["t00000"].keys()) == len(f_in["t00000"].keys()) + + tile_names = f_in["t00000"].keys() + + # Check that the resolutions and subdivisions have been added for + # each tile, and that the resolution pyramid of correct depth has been + # created for each tile. + for tile_name in tile_names: + assert ( + f_out[f"{tile_name}/resolutions"].shape[0] + == test_constants["NUM_RESOLUTIONS"] + ) + assert ( + f_out[f"{tile_name}/subdivisions"].shape[0] + == test_constants["NUM_RESOLUTIONS"] + ) + assert ( + len(f_out[f"t00000/{tile_name}"].keys()) + == test_constants["NUM_RESOLUTIONS"] + ) + + +def test_parse_mesospim_metadata(naive_bdv_directory, test_constants): + """ + Check the parse_mesospim_metadata function. The function should parse the + metadata stored in the h5_meta.txt file and return a list of dictionaries, + one for each tile. The results are checked by comparing the metadata to the + expected metadata, which is stored in the test_constants dictionary. + """ + meta_path = naive_bdv_directory / "test_data_bdv.h5_meta.txt" + + meta_data = parse_mesospim_metadata(meta_path) + + assert len(meta_data) == test_constants["NUM_TILES"] + # The tiles are alternating in channel names + for i in range(test_constants["NUM_TILES"]): + assert meta_data[i]["Laser"] == test_constants["CHANNELS"][i % 2] + assert ( + meta_data[i]["Pixelsize in um"] == test_constants["PIXEL_SIZE_XY"] + ) + assert meta_data[i]["z_stepsize"] == test_constants["PIXEL_SIZE_Z"] + + +def test_check_mesospim_directory(naive_bdv_directory): + xml_path, meta_path, h5_path = check_mesospim_directory( + naive_bdv_directory + ) + + assert xml_path == naive_bdv_directory / "test_data_bdv.xml" + assert meta_path == naive_bdv_directory / "test_data_bdv.h5_meta.txt" + assert h5_path == naive_bdv_directory / "test_data_bdv.h5" + + +@pytest.mark.parametrize( + "file_names, error_message", + [ + ( + ["test_data_bdv.xml", "test_data_bdv.h5_meta.txt"], + "Expected 1 h5 file, found 0", + ), + ( + ["test_data_bdv.xml", "test_data_bdv.h5"], + "Expected 1 h5_meta.txt file, found 0", + ), + ( + ["test_data_bdv.h5_meta.txt", "test_data_bdv.h5"], + "Expected 1 xml file, found 0", + ), + ], +) +def test_check_mesospim_directory_missing_files( + invalid_bdv_directory, file_names, error_message +): + """ + Add the specified files to the invalid directory and check that the + FileNotFoundError is raised with the correct error message. + """ + for file_name in file_names: + Path(invalid_bdv_directory / file_name).touch() + + with pytest.raises(FileNotFoundError) as e: + check_mesospim_directory(invalid_bdv_directory) + + assert error_message in str(e) + + +@pytest.mark.parametrize( + "file_names, error_message", + [ + ( + ["a_bdv.xml", "a_bdv.h5_meta.txt", "a_bdv.h5", "b_bdv.xml"], + "Expected 1 xml file, found 2", + ), + ( + [ + "a_bdv.xml", + "a_bdv.h5_meta.txt", + "a_bdv.h5", + "b_bdv.h5_meta.txt", + ], + "Expected 1 h5_meta.txt file, found 2", + ), + ( + ["a_bdv.xml", "a_bdv.h5_meta.txt", "a_bdv.h5", "b_bdv.h5"], + "Expected 1 h5 file, found 2", + ), + ], +) +def test_check_mesospim_directory_too_many_files( + invalid_bdv_directory, file_names, error_message +): + """ + Add the specified files to the invalid directory and check that the + FileNotFoundError is raised with the correct error message. + """ + for file_name in file_names: + Path(invalid_bdv_directory / file_name).touch() + + with pytest.raises(FileNotFoundError) as e: + check_mesospim_directory(invalid_bdv_directory) + + assert error_message in str(e) + + +def test_get_slice_attributes(naive_bdv_directory, test_constants): + xml_path = naive_bdv_directory / "test_data_bdv.xml" + tile_names = [f"s{i:02}" for i in range(test_constants["NUM_TILES"])] + + slice_attributes = get_slice_attributes(xml_path, tile_names) + + assert len(slice_attributes) == test_constants["NUM_TILES"] + + # The slices are arranged in a 2x2 grid with 2 channels + # The tiles in the test data are arranged in columns per channel + # Each column has its own illumination + # e.g. s00, s01 are channel 0, tile 0 and 1, illumination 0 + # s02, s03 are channel 1, tile 0 and 1, illumination 0 + # s04, s05 are channel 0, tile 2 and 3, illumination 1 + # s06, s07 are channel 1, tile 2 and 3, illumination 1 + for i in range(test_constants["NUM_TILES"]): + assert slice_attributes[tile_names[i]]["channel"] == str((i // 2) % 2) + assert slice_attributes[tile_names[i]]["tile"] == str( + i % 2 + (i // 4) * 2 + ) + assert slice_attributes[tile_names[i]]["illumination"] == str(i // 4) + assert slice_attributes[tile_names[i]]["angle"] == "0" + + +def test_get_big_stitcher_transforms(naive_bdv_directory, test_constants): + xml_path = naive_bdv_directory / "test_data_bdv.xml" + + transforms = get_big_stitcher_transforms(xml_path) + + assert np.equal( + transforms, test_constants["EXPECTED_TILE_POSITIONS"] + ).all() diff --git a/tests/test_unit/test_image_mosaic.py b/tests/test_unit/test_image_mosaic.py index fabef43..527402b 100644 --- a/tests/test_unit/test_image_mosaic.py +++ b/tests/test_unit/test_image_mosaic.py @@ -4,57 +4,32 @@ from brainglobe_stitch.image_mosaic import ImageMosaic -# The tiles lie in one z-plane and are arranged in a 2x2 grid with 2 channels. -# Each tile is 128x128x107 pixels (x, y, z). -# The tiles overlap by 10% in x and y (13 pixels). -# The tiles are arranged in the following pattern: -# channel 0 | channel 1 -# 00 10 | 04 05 -# 01 11 | 14 15 -NUM_TILES = 8 -NUM_RESOLUTIONS = 5 -NUM_CHANNELS = 2 -TILE_SIZE = (107, 128, 128) -# Expected tile config for the test data in test_data_bdv.h5 -# The tile positions are in pixels in x, y, z order -EXPECTED_TILE_CONFIG = [ - "dim=3", - "00;;(0,0,0)", - "01;;(0,115,0)", - "04;;(0,0,0)", - "05;;(0,115,0)", - "10;;(115,0,0)", - "11;;(115,115,0)", - "14;;(115,0,0)", - "15;;(115,115,0)", -] -# Expected tile positions for the test data in test_data_bdv.h5 -# The tile positions are in pixels in z, y, x order -EXPECTED_TILE_POSITIONS = [ - [0, 0, 0], - [0, 115, 0], - [0, 0, 0], - [0, 115, 0], - [0, 0, 115], - [0, 115, 115], - [0, 0, 115], - [0, 115, 115], -] - @pytest.fixture(scope="module") def image_mosaic(naive_bdv_directory): + """ + Fixture for creating an ImageMosaic object for testing. A clean directory + is created for this module using the naive_bdv_directory fixture. Tests + using this fixture will modify the directory. + + The __del__ method is called at the end of the module to close any open h5 + files. + + Yields + ------ + ImageMosaic + An ImageMosaic object for testing. + """ os.remove(naive_bdv_directory / "test_data_bdv_tile_config.txt") image_mosaic = ImageMosaic(naive_bdv_directory) yield image_mosaic - # Explicit call to clean up open h5 files + # Explicit call to close open h5 files image_mosaic.__del__() -def test_image_mosaic_init(image_mosaic, naive_bdv_directory): - image_mosaic = image_mosaic +def test_image_mosaic_init(image_mosaic, naive_bdv_directory, test_constants): assert image_mosaic.xml_path == naive_bdv_directory / "test_data_bdv.xml" assert ( image_mosaic.meta_path @@ -66,15 +41,24 @@ def test_image_mosaic_init(image_mosaic, naive_bdv_directory): == naive_bdv_directory / "test_data_bdv.h5_meta.txt" ) assert image_mosaic.h5_file is not None - assert len(image_mosaic.channel_names) == NUM_CHANNELS - assert len(image_mosaic.tiles) == NUM_TILES - assert len(image_mosaic.tile_names) == NUM_TILES - assert image_mosaic.x_y_resolution == 4.08 - assert image_mosaic.z_resolution == 5.0 - assert image_mosaic.num_channels == NUM_CHANNELS - - -def test_write_big_stitcher_tile_config(image_mosaic, naive_bdv_directory): + assert len(image_mosaic.channel_names) == test_constants["NUM_CHANNELS"] + assert image_mosaic.channel_names == test_constants["CHANNELS"] + assert len(image_mosaic.tiles) == test_constants["NUM_TILES"] + assert len(image_mosaic.tile_names) == test_constants["NUM_TILES"] + assert image_mosaic.x_y_resolution == test_constants["PIXEL_SIZE_XY"] + assert image_mosaic.z_resolution == test_constants["PIXEL_SIZE_Z"] + assert image_mosaic.num_channels == test_constants["NUM_CHANNELS"] + + +def test_write_big_stitcher_tile_config( + image_mosaic, naive_bdv_directory, test_constants +): + """ + Test the write_big_stitcher_tile_config method of the ImageMosaic class. + The expected result is a file with the same contents as + test_constants["EXPECTED_TILE_CONFIG"]. + """ + # Remove the test_data_bdv_tile_config.txt file if it exists if (naive_bdv_directory / "test_data_bdv_tile_config.txt").exists(): os.remove(naive_bdv_directory / "test_data_bdv_tile_config.txt") @@ -85,15 +69,61 @@ def test_write_big_stitcher_tile_config(image_mosaic, naive_bdv_directory): assert (naive_bdv_directory / "test_data_bdv_tile_config.txt").exists() with open(naive_bdv_directory / "test_data_bdv_tile_config.txt", "r") as f: - for idx, line in enumerate(f): - assert line.strip() == EXPECTED_TILE_CONFIG[idx] + for line, expected in zip( + f.readlines(), test_constants["EXPECTED_TILE_CONFIG"] + ): + assert line.strip() == expected + + +def test_stitch(mocker, image_mosaic, naive_bdv_directory, test_constants): + """ + Ensure that the stitch method calls run_big_stitcher with the correct + arguments. + """ + mock_completed_process = mocker.patch( + "subprocess.CompletedProcess", autospec=True + ) + mock_run_big_stitcher = mocker.patch( + "brainglobe_stitch.image_mosaic.run_big_stitcher", + return_value=mock_completed_process, + ) + mock_completed_process.stdout = "" + mock_completed_process.stderr = "" + + fiji_path = test_constants["MOCK_IMAGEJ_PATH"] + resolution_level = 2 + selected_channel = test_constants["CHANNELS"][0] + selected_channel_int = int(selected_channel.split()[0]) + downsample_z, downsample_y, downsample_x = tuple( + image_mosaic.tiles[0].resolution_pyramid[resolution_level] + ) + image_mosaic.stitch(fiji_path, resolution_level, selected_channel) + + mock_run_big_stitcher.assert_called_once_with( + fiji_path, + naive_bdv_directory / "test_data_bdv.xml", + naive_bdv_directory / "test_data_bdv_tile_config.txt", + False, + selected_channel_int, + downsample_x=downsample_x, + downsample_y=downsample_y, + downsample_z=downsample_z, + ) -def test_data_for_napari(image_mosaic): +def test_data_for_napari(image_mosaic, test_constants): + """ + Checks the return of the data_for_napari method. Each element of the + returned list should be a tuple containing the tile data and the expected + position of the tile in the fused image. The expected results are stored + in the dictionary returned by the test_constants fixture. + """ data = image_mosaic.data_for_napari(0) - assert len(data) == NUM_TILES + assert len(data) == test_constants["NUM_TILES"] - for i in range(NUM_TILES): - assert data[i][0].shape == TILE_SIZE - assert (data[i][1] == EXPECTED_TILE_POSITIONS[i]).all() + for tile_data, expected_pos in zip( + data, test_constants["EXPECTED_TILE_POSITIONS"] + ): + assert tile_data[0].shape == test_constants["TILE_SIZE"] + assert (tile_data[1] == expected_pos).all() diff --git a/tests/test_unit/test_stitching_widget.py b/tests/test_unit/test_stitching_widget.py index 980df0e..338ba81 100644 --- a/tests/test_unit/test_stitching_widget.py +++ b/tests/test_unit/test_stitching_widget.py @@ -1,4 +1,5 @@ from pathlib import Path +from typing import Generator import dask.array as da import napari.layers @@ -6,23 +7,45 @@ import pytest import brainglobe_stitch +from brainglobe_stitch.image_mosaic import ImageMosaic from brainglobe_stitch.stitching_widget import ( StitchingWidget, add_tiles_from_mosaic, ) -def test_add_tiles_from_mosaic(): - num_tiles = 4 +@pytest.fixture +def stitching_widget(make_napari_viewer_proxy) -> StitchingWidget: + viewer = make_napari_viewer_proxy() + stitching_widget = StitchingWidget(viewer) + + return stitching_widget - test_data = [] - for i in range(num_tiles): - test_data.append((da.ones((10, 10, 10)), np.array([i, i, i]))) - tile_names = [f"s{i:02}" for i in range(num_tiles)] +@pytest.fixture +def stitching_widget_with_mosaic( + stitching_widget, naive_bdv_directory +) -> Generator[StitchingWidget, None, None]: + stitching_widget.image_mosaic = ImageMosaic(naive_bdv_directory) + + yield stitching_widget + + stitching_widget.image_mosaic.__del__() + + +def test_add_tiles_from_mosaic( + naive_bdv_directory, stitching_widget_with_mosaic +): + """ + Test that the add_tiles_from_mosaic function correctly creates + napari.layers.Image objects from the data the correct values are stored + in the napari.layers.Image objects. + """ + image_mosaic = stitching_widget_with_mosaic.image_mosaic + test_data = image_mosaic.data_for_napari(0) for data, tile in zip( - test_data, add_tiles_from_mosaic(test_data, tile_names) + test_data, add_tiles_from_mosaic(test_data, image_mosaic) ): assert isinstance(tile, napari.layers.Image) assert (tile.data == data[0]).all() @@ -30,6 +53,12 @@ def test_add_tiles_from_mosaic(): def test_stitching_widget_init(make_napari_viewer_proxy): + """ + Test that the StitchingWidget is correctly initialized with the viewer + Currently tests that the viewer is correctly stored, the image_mosaic is + None, the tile_layers list is empty, and the resolution_to_display. + is set to 3. + """ viewer = make_napari_viewer_proxy() stitching_widget = StitchingWidget(viewer) @@ -39,9 +68,15 @@ def test_stitching_widget_init(make_napari_viewer_proxy): assert stitching_widget.resolution_to_display == 3 -def test_on_open_file_dialog_clicked(make_napari_viewer_proxy, mocker): - viewer = make_napari_viewer_proxy() - stitching_widget = StitchingWidget(viewer) +def test_on_open_file_dialog_clicked(stitching_widget, mocker): + """ + Test that the on_open_file_dialog_clicked method correctly sets the + working_directory attribute of the StitchingWidget to the provided + directory. The directory is provided by mocking the return of the + QFileDialog.getExistingDirectory method. The + check_and_load_mesospim_directory method is also mocked to prevent + actually opening and loading the files into the StitchingWidget. + """ test_dir = str(Path.home() / "test_dir") mocker.patch( "brainglobe_stitch.stitching_widget.QFileDialog.getExistingDirectory", @@ -57,9 +92,14 @@ def test_on_open_file_dialog_clicked(make_napari_viewer_proxy, mocker): assert stitching_widget.working_directory == Path(test_dir) -def test_on_mesospim_directory_text_edited(make_napari_viewer_proxy, mocker): - viewer = make_napari_viewer_proxy() - stitching_widget = StitchingWidget(viewer) +def test_on_mesospim_directory_text_edited(stitching_widget, mocker): + """ + Test that the on_mesospim_directory_text_edited method correctly sets + the working_directory attribute of the StitchingWidget to the provided + directory. The directory is provided by setting the text of the mesospim + directory text field. The check_and_load_mesospim_directory is mocked to + prevent actually opening and loading the files into the StitchingWidget. + """ test_dir = str(Path.home() / "test_dir") mocker.patch( "brainglobe_stitch.stitching_widget.StitchingWidget.check_and_load_mesospim_directory", @@ -72,9 +112,14 @@ def test_on_mesospim_directory_text_edited(make_napari_viewer_proxy, mocker): assert stitching_widget.working_directory == Path(test_dir) -def test_on_create_pyramid_button_clicked(make_napari_viewer_proxy, mocker): - viewer = make_napari_viewer_proxy() - stitching_widget = StitchingWidget(viewer) +def test_on_create_pyramid_button_clicked(stitching_widget, mocker): + """ + Test that the on_create_pyramid_button_clicked method correctly calls + the create_worker function with the correct arguments. The create_worker + function is mocked to prevent actually creating the pyramid. The create + pyramid button is disabled should be disabled after the method is called + and the add_tiles_button should be enabled. + """ stitching_widget.h5_path = Path.home() / "test_path" mock_create_worker = mocker.patch( "brainglobe_stitch.stitching_widget.create_worker", @@ -94,10 +139,14 @@ def test_on_create_pyramid_button_clicked(make_napari_viewer_proxy, mocker): def test_on_add_tiles_button_clicked( - make_napari_viewer_proxy, naive_bdv_directory, mocker + stitching_widget, naive_bdv_directory, mocker, test_constants ): - viewer = make_napari_viewer_proxy() - stitching_widget = StitchingWidget(viewer) + """ + Test that the on_add_tiles_button_clicked method correctly calls the + create_worker function once. Following the call to the method, the + image_mosaic attribute should be set to an ImageMosaic object and the + fuse_channel_dropdown should be populated with the correct values. + """ stitching_widget.working_directory = naive_bdv_directory mock_create_worker = mocker.patch( @@ -109,11 +158,21 @@ def test_on_add_tiles_button_clicked( mock_create_worker.assert_called_once() + assert stitching_widget.image_mosaic is not None + + dropdown_values = [ + stitching_widget.fuse_channel_dropdown.itemText(i) + for i in range(stitching_widget.fuse_channel_dropdown.count()) + ] + assert dropdown_values == test_constants["CHANNELS"] + @pytest.mark.parametrize("num_layers", [1, 2, 5]) -def test_set_tile_layers_multiple(make_napari_viewer_proxy, num_layers): - viewer = make_napari_viewer_proxy() - stitching_widget = StitchingWidget(viewer) +def test_set_tile_layers_multiple(stitching_widget, num_layers): + """ + Test that the _set_tile_layers method correctly adds the provided + napari.layers.Image objects to the tile_layers list and to the viewer. + """ test_data = da.ones((10, 10, 10)) test_layers = [] @@ -124,15 +183,18 @@ def test_set_tile_layers_multiple(make_napari_viewer_proxy, num_layers): test_layers.append(test_layer) assert len(stitching_widget.tile_layers) == num_layers - for i in range(num_layers): - assert stitching_widget.tile_layers[i] == test_layers[i] + assert stitching_widget._viewer.layers == test_layers + assert stitching_widget.tile_layers == test_layers def test_check_and_load_mesospim_directory( - make_napari_viewer_proxy, naive_bdv_directory + stitching_widget, naive_bdv_directory ): - viewer = make_napari_viewer_proxy() - stitching_widget = StitchingWidget(viewer) + """ + Sets the working_directory attribute of the StitchingWidget to the + naive_bdv_directory and checks that the correct paths are set for the + StitchingWidget, and that the add_tiles_button is enabled. + """ stitching_widget.working_directory = naive_bdv_directory stitching_widget.check_and_load_mesospim_directory() @@ -149,10 +211,14 @@ def test_check_and_load_mesospim_directory( def test_check_and_load_mesospim_directory_no_pyramid( - make_napari_viewer_proxy, bdv_directory_function_level, mocker + stitching_widget, bdv_directory_function_level, mocker ): - viewer = make_napari_viewer_proxy() - stitching_widget = StitchingWidget(viewer) + """ + Uses the bdv_directory_function_level fixture to create a clean + mesospim directory. This should trigger the show_warning method to + inform the user that the resolution pyramid was not found and enable + the create_pyramid_button. + """ stitching_widget.working_directory = bdv_directory_function_level mock_show_warning = mocker.patch( @@ -171,13 +237,17 @@ def test_check_and_load_mesospim_directory_no_pyramid( ["test_data_bdv.h5", "test_data_bdv.xml", "test_data_bdv.h5_meta.txt"], ) def test_check_and_load_mesospim_directory_missing_files( - make_napari_viewer_proxy, + stitching_widget, bdv_directory_function_level, mocker, file_to_remove, ): - viewer = make_napari_viewer_proxy() - stitching_widget = StitchingWidget(viewer) + """ + Uses the bdv_directory_function_level fixture to create a clean + mesospim directory and then remove one of the files (file_to_remove). + This should trigger a show_warning message to inform the user that the + mesoSPIM directory is not valid. + """ stitching_widget.working_directory = bdv_directory_function_level error_message = "mesoSPIM directory not valid" @@ -189,3 +259,142 @@ def test_check_and_load_mesospim_directory_missing_files( stitching_widget.check_and_load_mesospim_directory() mock_show_warning.assert_called_once_with(error_message) + + +def test_on_open_file_dialog_imagej_clicked(stitching_widget, mocker): + """ + Mocks the QFileDialog.getOpenFileName method to return a mock imageJ + directory. The check_imagej_path method is also mocked as the path doesn't + point to a valid imageJ executable. The imageJ path text field should have + the mock imageJ directory and the imageJ path attribute of the stitching + widget should be set to the mock imageJ directory. + """ + imagej_dir = str(Path.home() / "imageJ") + mocker.patch( + "brainglobe_stitch.stitching_widget.QFileDialog.getOpenFileName", + return_value=(imagej_dir, ""), + ) + mocker.patch( + "brainglobe_stitch.stitching_widget.StitchingWidget.check_imagej_path", + ) + + stitching_widget._on_open_file_dialog_imagej_clicked() + + assert stitching_widget.imagej_path_text_field.text() == imagej_dir + assert stitching_widget.imagej_path == Path(imagej_dir) + + +def test_on_imagej_path_text_edited(stitching_widget, mocker): + """ + Manually sets the imageJ path text field to a mock imageJ directory to + mimic a user manually entering or copying a path to imageJ. The imagej_path + attribute of the StitchingWidget should be set to the mock directory. + """ + imagej_dir = str(Path.home() / "imageJ") + mocker.patch( + "brainglobe_stitch.stitching_widget.StitchingWidget.check_imagej_path", + ) + + stitching_widget.imagej_path_text_field.setText(imagej_dir) + + stitching_widget._on_imagej_path_text_edited() + + assert stitching_widget.imagej_path == Path(imagej_dir) + + +def test_on_stitch_button_clicked( + stitching_widget_with_mosaic, naive_bdv_directory, mocker +): + """ + Uses the stitching_widget_with_mosaic fixture to create a StitchingWidget + with an ImageMosaic object. The mock_stitch_function is used to prevent + the actual stitching of the ImageMosaic object. + Tests that the _on_stitch_button_clicked method correctly calls the stitch + method of the ImageMosaic object with the correct arguments. + """ + stitching_widget = stitching_widget_with_mosaic + + mock_stitch_function = mocker.patch( + "brainglobe_stitch.stitching_widget.ImageMosaic.stitch", + autospec=True, + ) + + stitching_widget._on_stitch_button_clicked() + + mock_stitch_function.assert_called_once_with( + stitching_widget.image_mosaic, + stitching_widget.imagej_path, + resolution_level=2, + selected_channel="", + ) + + +def test_check_imagej_path_valid(stitching_widget): + """ + Creates a mock imageJ file in the home directory and sets it as the + imageJ path of the StitchingWidget. The check_imagej_path method should + enable the stitch button as the path is valid. + + The mock imageJ file is removed after the test. + """ + stitching_widget.imagej_path = Path.home() / "imageJ" + stitching_widget.imagej_path.touch(exist_ok=True) + stitching_widget.check_imagej_path() + # Clean up before assertions to make sure nothing is left behind + # regardless of test outcome + stitching_widget.imagej_path.unlink() + + assert stitching_widget.stitch_button.isEnabled() + + +def test_check_imagej_path_invalid(stitching_widget, mocker): + """ + Sets the imageJ path of the StitchingWidget to a non-existent directory. + The check_imagej_path method should show a warning message to the user. + """ + stitching_widget.imagej_path = Path.home() / "imageJ" + + mock_show_warning = mocker.patch( + "brainglobe_stitch.stitching_widget.show_warning" + ) + error_message = ( + "ImageJ path not valid. " + "Please select a valid path to the imageJ executable." + ) + + stitching_widget.check_imagej_path() + + mock_show_warning.assert_called_once_with(error_message) + + +def test_update_tiles_from_mosaic( + stitching_widget_with_mosaic, naive_bdv_directory, test_constants +): + """ + Uses the stitching_widget_with_mosaic fixture to create a StitchingWidget + with an ImageMosaic object. The tiles from the ImageMosaic object are + added to the tile_layers list. The update_tiles_from_mosaic method is + called with mock data and offsets. The data and offset of each + napari.layers.Image are checked. + """ + stitching_widget = stitching_widget_with_mosaic + num_tiles = test_constants["NUM_TILES"] + test_data = [] + + initial_data = stitching_widget.image_mosaic.data_for_napari(0) + + for tile in add_tiles_from_mosaic( + initial_data, stitching_widget.image_mosaic + ): + stitching_widget.tile_layers.append(tile) + + for i in range(num_tiles): + test_data.append( + (da.ones(initial_data[0][0].shape) + i, np.array([i, i, i])) + ) + + stitching_widget.update_tiles_from_mosaic(test_data) + + for tile, test_data in zip(stitching_widget.tile_layers, test_data): + assert (tile.data == test_data[0]).all() + assert (tile.translate == test_data[1]).all() diff --git a/tests/test_unit/test_tile.py b/tests/test_unit/test_tile.py index eb9ed2d..6c6d6c1 100644 --- a/tests/test_unit/test_tile.py +++ b/tests/test_unit/test_tile.py @@ -24,6 +24,6 @@ def test_tile_init(): assert len(tile.data_pyramid) == 0 assert len(tile.resolution_pyramid) == 0 assert tile.channel_id == channel_id - assert tile.channel_name is None + assert tile.channel_name == "" assert tile.illumination_id == illumination_id assert tile.angle == angle