diff --git a/containers/main.py b/containers/main.py
index b0f4b43..07c4f5d 100644
--- a/containers/main.py
+++ b/containers/main.py
@@ -127,33 +127,12 @@ def main(
     filterwarnings(action="ignore", message="No cached namespaces found in .*")
     filterwarnings(action="ignore", message="Ignoring cached namespace .*")
 
-    # Set SpikeInterface global job kwargs
-    si.set_global_job_kwargs(
-        n_jobs=os.cpu_count(),
-        chunk_duration="1s",
-        progress_bar=False
-    )
-
-    # Create folders
-    data_folder = Path("data/")
-    data_folder.mkdir(exist_ok=True)
-
-    scratch_folder = Path("scratch/")
-    scratch_folder.mkdir(exist_ok=True)
-
-    results_folder = Path("results/")
-    results_folder.mkdir(exist_ok=True)
-    tmp_folder = scratch_folder / "tmp"
-    if tmp_folder.is_dir():
-        shutil.rmtree(tmp_folder)
-    tmp_folder.mkdir()
-
     # Checks
     if source_name not in ["local", "s3", "dandi"]:
         logger.error(f"Source {source_name} not supported. Choose from: local, s3, dandi.")
         raise ValueError(f"Source {source_name} not supported. Choose from: local, s3, dandi.")
 
-    # TODO: here we could leverage spikeinterface and add more options
+    # TODO: here we could eventually leverage spikeinterface and add more options
     if source_data_type not in ["nwb", "spikeglx"]:
         logger.error(f"Data type {source_data_type} not supported. Choose from: nwb, spikeglx.")
         raise ValueError(f"Data type {source_data_type} not supported. Choose from: nwb, spikeglx.")
@@ -174,7 +153,28 @@ def main(
         output_s3_bucket = output_path_parsed.split("/")[0]
         output_s3_bucket_folder = "/".join(output_path_parsed.split("/")[1:])
 
-    s3_client = boto3.client("s3")
+    # Set SpikeInterface global job kwargs
+    si.set_global_job_kwargs(
+        n_jobs=os.cpu_count(),
+        chunk_duration="1s",
+        progress_bar=False
+    )
+
+    # Create SpikeInterface folders
+    data_folder = Path("data/")
+    data_folder.mkdir(exist_ok=True)
+    scratch_folder = Path("scratch/")
+    scratch_folder.mkdir(exist_ok=True)
+    results_folder = Path("results/")
+    results_folder.mkdir(exist_ok=True)
+    tmp_folder = scratch_folder / "tmp"
+    if tmp_folder.is_dir():
+        shutil.rmtree(tmp_folder)
+    tmp_folder.mkdir()
+
+    # S3 client
+    if source_name == "s3" or output_destination == "s3":
+        s3_client = boto3.client("s3")
 
     # Test with toy recording
     if test_with_toy_recording:
@@ -198,7 +198,6 @@ def main(
             bucket_name=bucket_name,
             file_path=file_path,
         )
-
         logger.info("Reading recording...")
         # E.g.: se.read_spikeglx(folder_path="/data", stream_id="imec.ap")
         if source_data_type == "spikeglx":
@@ -207,17 +206,16 @@ def main(
             recording = se.read_nwb_recording(file_path=f"/data/{file_name}", **recording_kwargs)
         recording_name = "recording_on_s3"
 
+    # Load data from DANDI archive
     elif source_name == "dandi":
         dandiset_s3_file_url = source_data_paths["file"]
         if not dandiset_s3_file_url.startswith("https://dandiarchive"):
             raise Exception(
                f"DANDISET_S3_FILE_URL should be a valid Dandiset S3 url. Value received was: {dandiset_s3_file_url}"
            )
-
        if not test_with_subrecording:
            logger.info(f"Downloading dataset: {dandiset_s3_file_url}")
            download_file_from_url(dandiset_s3_file_url)
-
            logger.info("Reading recording from NWB...")
            recording = se.read_nwb_recording(file_path="/data/filename.nwb", **recording_kwargs)
        else:
@@ -225,12 +223,18 @@ def main(
             recording = se.read_nwb_recording(file_path=dandiset_s3_file_url, stream_mode="fsspec", **recording_kwargs)
         recording_name = "recording_on_dandi"
 
+    # TODO - Load data from local files
+    elif source_name == "local":
+        pass
+
+    # Run with subrecording
     if test_with_subrecording:
         n_frames = int(min(test_subrecording_n_frames, recording.get_num_frames()))
         recording = recording.frame_slice(start_frame=0, end_frame=n_frames)
 
     # ------------------------------------------------------------------------------------
     # Preprocessing
+    # ------------------------------------------------------------------------------------
     logger.info("Starting preprocessing...")
     preprocessing_notes = ""
     preprocessed_folder = tmp_folder / "preprocessed"
@@ -238,13 +242,13 @@ def main(
     logger.info(f"\tDuration: {np.round(recording.get_total_duration(), 2)} s")
 
     if "inter_sample_shift" in recording.get_property_keys():
-        recording_ps_full = spre.phase_shift(recording, **preprocessing_params["phase_shift"])
+        recording_ps_full = spre.phase_shift(recording, **preprocessing_kwargs["phase_shift"])
     else:
         recording_ps_full = recording
-    recording_hp_full = spre.highpass_filter(recording_ps_full, **preprocessing_params["highpass_filter"])
+    recording_hp_full = spre.highpass_filter(recording_ps_full, **preprocessing_kwargs["highpass_filter"])
 
     # IBL bad channel detection
-    _, channel_labels = spre.detect_bad_channels(recording_hp_full, **preprocessing_params["detect_bad_channels"])
+    _, channel_labels = spre.detect_bad_channels(recording_hp_full, **preprocessing_kwargs["detect_bad_channels"])
     dead_channel_mask = channel_labels == "dead"
     noise_channel_mask = channel_labels == "noise"
     out_channel_mask = channel_labels == "out"
@@ -257,7 +261,7 @@ def main(
     out_channel_ids = recording_hp_full.channel_ids[out_channel_mask]
     all_bad_channel_ids = np.concatenate((dead_channel_ids, noise_channel_ids, out_channel_ids))
 
-    max_bad_channel_fraction_to_remove = preprocessing_params["max_bad_channel_fraction_to_remove"]
+    max_bad_channel_fraction_to_remove = preprocessing_kwargs["max_bad_channel_fraction_to_remove"]
     if len(all_bad_channel_ids) >= int(max_bad_channel_fraction_to_remove * recording.get_num_channels()):
         logger.info(
             f"\tMore than {max_bad_channel_fraction_to_remove * 100}% bad channels ({len(all_bad_channel_ids)}). "
@@ -267,28 +271,28 @@ def main(
         # in this case, we don't bother sorting
         return
     else:
-        if preprocessing_params["remove_out_channels"]:
+        if preprocessing_kwargs["remove_out_channels"]:
             logger.info(f"\tRemoving {len(out_channel_ids)} out channels")
             recording_rm_out = recording_hp_full.remove_channels(out_channel_ids)
             preprocessing_notes += f"\n- Removed {len(out_channel_ids)} outside of the brain."
         else:
             recording_rm_out = recording_hp_full
-        recording_processed_cmr = spre.common_reference(recording_rm_out, **preprocessing_params["common_reference"])
+        recording_processed_cmr = spre.common_reference(recording_rm_out, **preprocessing_kwargs["common_reference"])
 
         bad_channel_ids = np.concatenate((dead_channel_ids, noise_channel_ids))
         recording_interp = spre.interpolate_bad_channels(recording_rm_out, bad_channel_ids)
         recording_hp_spatial = spre.highpass_spatial_filter(
-            recording_interp, **preprocessing_params["highpass_spatial_filter"]
+            recording_interp, **preprocessing_kwargs["highpass_spatial_filter"]
         )
 
-        preproc_strategy = preprocessing_params["preprocessing_strategy"]
+        preproc_strategy = preprocessing_kwargs["preprocessing_strategy"]
         if preproc_strategy == "cmr":
             recording_processed = recording_processed_cmr
         else:
             recording_processed = recording_hp_spatial
 
-        if preprocessing_params["remove_bad_channels"]:
+        if preprocessing_kwargs["remove_bad_channels"]:
             logger.info(f"\tRemoving {len(bad_channel_ids)} channels after {preproc_strategy} preprocessing")
             recording_processed = recording_processed.remove_channels(bad_channel_ids)
             preprocessing_notes += f"\n- Removed {len(bad_channel_ids)} bad channels after preprocessing.\n"
@@ -300,6 +304,8 @@ def main(
 
     # ------------------------------------------------------------------------------------
     # Spike Sorting
+    # ------------------------------------------------------------------------------------
+    sorter_name = sorter_kwargs["sorter_name"]
     logger.info(f"\n\nStarting spike sorting with {sorter_name}")
     spikesorting_notes = ""
     sorting_params = None
@@ -313,7 +319,7 @@ def main(
     if recording_processed.get_num_segments() > 1:
         recording_processed = si.concatenate_recordings([recording_processed])
 
-    # run ks2.5
+    # Run sorter
     try:
         sorting = ss.run_sorter(
             sorter_name,
@@ -334,6 +340,7 @@ def main(
 
     # remove empty units
     sorting = sorting.remove_empty_units()
+    # remove spikes beyond num_samples (if any)
     sorting = sc.remove_excess_spikes(sorting=sorting, recording=recording_processed)
     logger.info(f"\tSorting output without empty units: {sorting}")
 
@@ -353,6 +360,7 @@ def main(
 
     # ------------------------------------------------------------------------------------
     # Postprocessing
+    # ------------------------------------------------------------------------------------
     logger.info("\n\Starting postprocessing...")
     postprocessing_notes = ""
     t_postprocessing_start = time.perf_counter()
 
     # first extract some raw waveforms in memory to deduplicate based on peak alignment
     wf_dedup_folder = tmp_folder / "postprocessed" / recording_name
     we_raw = si.extract_waveforms(
-        recording_processed, sorting, folder=wf_dedup_folder, **postprocessing_params["waveforms_deduplicate"]
+        recording_processed, sorting, folder=wf_dedup_folder, **postprocessing_kwargs["waveforms_deduplicate"]
     )
     # de-duplication
     sorting_deduplicated = sc.remove_redundant_units(we_raw, duplicate_threshold=curation_params["duplicate_threshold"])
@@ -389,32 +397,34 @@ def main(
         sparsity=sparsity,
         sparse=True,
         overwrite=True,
-        **postprocessing_params["waveforms"],
+        **postprocessing_kwargs["waveforms"],
     )
     logger.info("\tComputing spike amplitides")
-    spike_amplitudes = spost.compute_spike_amplitudes(we, **postprocessing_params["spike_amplitudes"])
+    spike_amplitudes = spost.compute_spike_amplitudes(we, **postprocessing_kwargs["spike_amplitudes"])
     logger.info("\tComputing unit locations")
-    unit_locations = spost.compute_unit_locations(we, **postprocessing_params["locations"])
+    unit_locations = spost.compute_unit_locations(we, **postprocessing_kwargs["locations"])
     logger.info("\tComputing spike locations")
-    spike_locations = spost.compute_spike_locations(we, **postprocessing_params["locations"])
+    spike_locations = spost.compute_spike_locations(we, **postprocessing_kwargs["locations"])
     logger.info("\tComputing correlograms")
-    ccg, bins = spost.compute_correlograms(we, **postprocessing_params["correlograms"])
+    ccg, bins = spost.compute_correlograms(we, **postprocessing_kwargs["correlograms"])
     logger.info("\tComputing ISI histograms")
-    isi, bins = spost.compute_isi_histograms(we, **postprocessing_params["isis"])
+    isi, bins = spost.compute_isi_histograms(we, **postprocessing_kwargs["isis"])
     logger.info("\tComputing template similarity")
-    sim = spost.compute_template_similarity(we, **postprocessing_params["similarity"])
+    sim = spost.compute_template_similarity(we, **postprocessing_kwargs["similarity"])
     logger.info("\tComputing template metrics")
-    tm = spost.compute_template_metrics(we, **postprocessing_params["template_metrics"])
+    tm = spost.compute_template_metrics(we, **postprocessing_kwargs["template_metrics"])
     logger.info("\tComputing PCA")
-    pca = spost.compute_principal_components(we, **postprocessing_params["principal_components"])
+    pca = spost.compute_principal_components(we, **postprocessing_kwargs["principal_components"])
     logger.info("\tComputing quality metrics")
-    qm = sqm.compute_quality_metrics(we, **postprocessing_params["quality_metrics"])
+    qm = sqm.compute_quality_metrics(we, **postprocessing_kwargs["quality_metrics"])
 
     t_postprocessing_end = time.perf_counter()
     elapsed_time_postprocessing = np.round(t_postprocessing_end - t_postprocessing_start, 2)
     logger.info(f"Postprocessing time: {elapsed_time_postprocessing}s")
 
-    ###### CURATION ##############
+    # ------------------------------------------------------------------------------------
+    # Curation
+    # ------------------------------------------------------------------------------------
     logger.info("\n\Starting curation...")
     curation_notes = ""
     t_curation_start = time.perf_counter()
@@ -452,10 +462,15 @@ def main(
     elapsed_time_curation = np.round(t_curation_end - t_curation_start, 2)
     logger.info(f"Curation time: {elapsed_time_curation}s")
 
+    # ------------------------------------------------------------------------------------
     # TODO: Visualization with FIGURL (needs credentials)
+    # ------------------------------------------------------------------------------------
+
+    # ------------------------------------------------------------------------------------
     # Conversion and upload
+    # ------------------------------------------------------------------------------------
     logger.info("Writing sorting results to NWB...")
     metadata = {
         "NWBFile": {
@@ -603,7 +618,7 @@ def main(
     visualization_kwargs = json.loads(os.environ.get("SI_VISUALIZATION_KWARGS", "{}"))
 
     # Get output kwargs from ENV variables
-    output_kwargs = json.loads(os.environ.get("SI_OUTPUT_KWARGS", "{}"))
+    output_kwargs = json.loads(os.environ.get("SI_OUTPUT_DATA_KWARGS", "{}"))
     output_destination = validate_not_none(output_kwargs, "output_destination")
     output_path = validate_not_none(output_kwargs, "output_path")
 
diff --git a/rest/clients/local_docker.py b/rest/clients/local_docker.py
index ed738b0..fbd2042 100644
--- a/rest/clients/local_docker.py
+++ b/rest/clients/local_docker.py
@@ -11,6 +11,7 @@
     PostprocessingKwargs,
     CurationKwargs,
     VisualizationKwargs,
+    OutputDataKwargs
 )
 
 
@@ -30,6 +31,7 @@ def run_sorting(
     postprocessing_kwargs: PostprocessingKwargs,
     curation_kwargs: CurationKwargs,
     visualization_kwargs: VisualizationKwargs,
+    output_data_kwargs: OutputDataKwargs,
 ) -> None:
     # Pass kwargs as environment variables to the container
     env_vars = dict(
@@ -41,6 +43,7 @@ def run_sorting(
         SI_POSTPROCESSING_KWARGS=postprocessing_kwargs.json(),
         SI_CURATION_KWARGS=curation_kwargs.json(),
         SI_VISUALIZATION_KWARGS=visualization_kwargs.json(),
+        SI_OUTPUT_DATA_KWARGS=output_data_kwargs.json(),
     )
 
     # Local volumes to mount
diff --git a/rest/models/sorting.py b/rest/models/sorting.py
index e78942f..33f71e7 100644
--- a/rest/models/sorting.py
+++ b/rest/models/sorting.py
@@ -1,6 +1,7 @@
 from pydantic import BaseModel, Field, Extra
 from typing import Optional, Dict, List, Union, Tuple
 from enum import Enum
+from datetime import datetime
 
 
 # ------------------------------
@@ -10,10 +11,13 @@ class RunAt(str, Enum):
     aws = "aws"
     local = "local"
 
+def default_run_identifier():
+    return datetime.now().strftime("%Y%m%d-%H%M%S")
+
 class RunKwargs(BaseModel):
     run_at: RunAt = Field(..., description="Where to run the sorting job. Choose from: aws, local.")
-    run_identifier: str = Field(..., description="Unique identifier for the run.")
-    run_description: str = Field(..., description="Description of the run.")
+    run_identifier: str = Field(default_factory=default_run_identifier, description="Unique identifier for the run.")
+    run_description: str = Field(default="", description="Description of the run.")
     test_with_toy_recording: bool = Field(default=False, description="Whether to test with a toy recording.")
     test_with_subrecording: bool = Field(default=False, description="Whether to test with a subrecording.")
     test_subrecording_n_frames: Optional[int] = Field(default=30000, description="Number of frames to use for the subrecording.")
@@ -278,4 +282,3 @@ class VisualizationKwargs(BaseModel):
 
     timeseries: Timeseries
     drift: Drift
-
diff --git a/rest/routes/sorting.py b/rest/routes/sorting.py
index 188724c..30f8dac 100644
--- a/rest/routes/sorting.py
+++ b/rest/routes/sorting.py
@@ -17,6 +17,7 @@
     PostprocessingKwargs,
     CurationKwargs,
     VisualizationKwargs,
+    OutputDataKwargs
 )
 
 
@@ -32,6 +33,7 @@ def sorting_background_task(
     postprocessing_kwargs: PostprocessingKwargs,
     curation_kwargs: CurationKwargs,
     visualization_kwargs: VisualizationKwargs,
+    output_data_kwargs: OutputDataKwargs,
 ):
     # Run sorting and update db entry status
     db_client = DatabaseClient(connection_string=settings.DB_CONNECTION_STRING)
@@ -49,6 +51,7 @@ def sorting_background_task(
             postprocessing_kwargs=postprocessing_kwargs,
             curation_kwargs=curation_kwargs,
             visualization_kwargs=visualization_kwargs,
+            output_data_kwargs=output_data_kwargs,
         )
     elif run_at == "aws":
         # TODO: Implement this
@@ -77,12 +80,9 @@ async def route_run_sorting(
     postprocessing_kwargs: PostprocessingKwargs,
     curation_kwargs: CurationKwargs,
     visualization_kwargs: VisualizationKwargs,
+    output_data_kwargs: OutputDataKwargs,
     background_tasks: BackgroundTasks
 ) -> JSONResponse:
-    if not run_kwargs.run_identifier:
-        run_identifier = datetime.now().strftime("%Y%m%d%H%M%S")
-    else:
-        run_identifier = run_kwargs.run_identifier
     try:
         # Create Database entries
         db_client = DatabaseClient(connection_string=settings.DB_CONNECTION_STRING)
@@ -121,6 +121,7 @@ async def route_run_sorting(
             postprocessing_kwargs=postprocessing_kwargs,
             curation_kwargs=curation_kwargs,
             visualization_kwargs=visualization_kwargs,
+            output_data_kwargs=output_data_kwargs,
         )
     except Exception as e:
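
Note: OutputDataKwargs is imported in rest/clients/local_docker.py and rest/routes/sorting.py above, but its definition in rest/models/sorting.py is not part of this diff. The following is a minimal sketch of what such a pydantic model could look like, based only on how containers/main.py consumes SI_OUTPUT_DATA_KWARGS (output_destination and output_path are both required via validate_not_none, and output_destination is compared against "s3"); the OutputDestination enum and its "local" value are assumptions, not taken from the diff:

from enum import Enum

from pydantic import BaseModel, Field


class OutputDestination(str, Enum):
    # "s3" is checked explicitly in containers/main.py; "local" is an assumed alternative
    s3 = "s3"
    local = "local"


class OutputDataKwargs(BaseModel):
    output_destination: OutputDestination = Field(..., description="Destination of the sorting results. Choose from: s3, local.")
    output_path: str = Field(..., description="Output path, e.g. an S3 bucket/folder or a local folder.")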
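For reference, the round trip assumed by these changes is: LocalDockerClient.run_sorting() serializes the model with output_data_kwargs.json() into the SI_OUTPUT_DATA_KWARGS environment variable, and containers/main.py reads it back with json.loads. A small, self-contained illustration with made-up values:

import json
import os

# What the REST client would place in the container environment (hypothetical values).
os.environ["SI_OUTPUT_DATA_KWARGS"] = json.dumps(
    {"output_destination": "s3", "output_path": "my-bucket/sorting-results/run-20230101-000000"}
)

# What containers/main.py does after this diff:
output_kwargs = json.loads(os.environ.get("SI_OUTPUT_DATA_KWARGS", "{}"))
output_destination = output_kwargs["output_destination"]  # validated with validate_not_none() in main.py
output_path = output_kwargs["output_path"]
print(output_destination, output_path)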