diff --git a/.gitignore b/.gitignore index a9d0917..8801548 100644 --- a/.gitignore +++ b/.gitignore @@ -159,4 +159,5 @@ yarn-error.log* # Miscellaneous *.fuse_hidden* results/ +logs/ *.nwb \ No newline at end of file diff --git a/containers/Dockerfile.ks2_5 b/containers/Dockerfile.ks2_5 index fc99494..90bb79b 100644 --- a/containers/Dockerfile.ks2_5 +++ b/containers/Dockerfile.ks2_5 @@ -1,34 +1,34 @@ # Spike sorters image -FROM spikeinterface/kilosort2_5-compiled-base:0.2.0 as ks25base - -# NVIDIA-ready Image -FROM nvidia/cuda:11.6.2-base-ubuntu20.04 - -# Installing Python with miniconda -RUN apt-get update && \ - apt-get install -y build-essential && \ - apt-get install -y wget && \ - apt-get install -y git && \ - apt-get clean && \ - rm -rf /var/lib/apt/lists/* - -ENV CONDA_DIR /home/miniconda3 -ENV LATEST_CONDA_SCRIPT "Miniconda3-py39_23.5.2-0-Linux-x86_64.sh" - -RUN wget --quiet https://repo.anaconda.com/miniconda/$LATEST_CONDA_SCRIPT -O ~/miniconda.sh && \ - bash ~/miniconda.sh -b -p $CONDA_DIR && \ - rm ~/miniconda.sh -ENV PATH=$CONDA_DIR/bin:$PATH - -# Bring Sorter and matlab-related files -COPY --from=ks25base /usr/bin/mlrtapp/ks2_5_compiled /usr/bin/mlrtapp/ks2_5_compiled -ENV PATH="/usr/bin/mlrtapp:${PATH}" -COPY --from=ks25base /opt/matlabruntime /opt/matlabruntime -ENV PATH="/opt/matlabruntime:${PATH}" -COPY --from=ks25base /usr/lib/x86_64-linux-gnu/libXt.so.6 /usr/lib/x86_64-linux-gnu/libXt.so.6 -COPY --from=ks25base /usr/lib/x86_64-linux-gnu/libSM.so.6 /usr/lib/x86_64-linux-gnu/libSM.so.6 -COPY --from=ks25base /usr/lib/x86_64-linux-gnu/libICE.so.6 /usr/lib/x86_64-linux-gnu/libICE.so.6 -ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/matlabruntime/R2022b/runtime/glnxa64:/opt/matlabruntime/R2022b/bin/glnxa64:/opt/matlabruntime/R2022b/sys/os/glnxa64:/opt/matlabruntime/R2022b/sys/opengl/lib/glnxa64:/opt/matlabruntime/R2022b/extern/bin/glnxa64 +FROM spikeinterface/kilosort2_5-compiled-base:0.2.0 + +# # NVIDIA-ready Image +# FROM nvidia/cuda:11.6.2-base-ubuntu20.04 + +# # Installing Python with miniconda +# RUN apt-get update && \ +# apt-get install -y build-essential && \ +# apt-get install -y wget && \ +# apt-get install -y git && \ +# apt-get clean && \ +# rm -rf /var/lib/apt/lists/* + +# ENV CONDA_DIR /home/miniconda3 +# ENV LATEST_CONDA_SCRIPT "Miniconda3-py39_23.5.2-0-Linux-x86_64.sh" + +# RUN wget --quiet https://repo.anaconda.com/miniconda/$LATEST_CONDA_SCRIPT -O ~/miniconda.sh && \ +# bash ~/miniconda.sh -b -p $CONDA_DIR && \ +# rm ~/miniconda.sh +# ENV PATH=$CONDA_DIR/bin:$PATH + +# # Bring Sorter and matlab-related files +# COPY --from=ks25base /usr/bin/mlrtapp/ks2_5_compiled /usr/bin/mlrtapp/ks2_5_compiled +# ENV PATH="/usr/bin/mlrtapp:${PATH}" +# COPY --from=ks25base /opt/matlabruntime /opt/matlabruntime +# ENV PATH="/opt/matlabruntime:${PATH}" +# COPY --from=ks25base /usr/lib/x86_64-linux-gnu/libXt.so.6 /usr/lib/x86_64-linux-gnu/libXt.so.6 +# COPY --from=ks25base /usr/lib/x86_64-linux-gnu/libSM.so.6 /usr/lib/x86_64-linux-gnu/libSM.so.6 +# COPY --from=ks25base /usr/lib/x86_64-linux-gnu/libICE.so.6 /usr/lib/x86_64-linux-gnu/libICE.so.6 +# ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/matlabruntime/R2022b/runtime/glnxa64:/opt/matlabruntime/R2022b/bin/glnxa64:/opt/matlabruntime/R2022b/sys/os/glnxa64:/opt/matlabruntime/R2022b/sys/opengl/lib/glnxa64:/opt/matlabruntime/R2022b/extern/bin/glnxa64 # Copy requirements and script COPY requirements.txt . @@ -37,11 +37,11 @@ RUN pip install -r requirements.txt WORKDIR /app COPY main.py . COPY utils.py . -COPY light_server.py . 
+# COPY light_server.py . RUN mkdir /data RUN mkdir /logs # Get Python stdout logs ENV PYTHONUNBUFFERED=1 -CMD ["python", "light_server.py"] \ No newline at end of file +CMD ["python", "main.py"] \ No newline at end of file diff --git a/containers/README.md b/containers/README.md index c85cdd0..0d22f95 100644 --- a/containers/README.md +++ b/containers/README.md @@ -16,7 +16,8 @@ Basic infrastructure makes use of the following AWS services: Build docker image: ```bash -$ DOCKER_BUILDKIT=1 docker build -t -f . +$ DOCKER_BUILDKIT=1 docker build -t ghcr.io/catalystneuro/si-sorting-ks25:latest -f Dockerfile.ks2_5 . +$ DOCKER_BUILDKIT=1 docker build -t ghcr.io/catalystneuro/si-sorting-ks3:latest -f Dockerfile.ks3 . ``` Run locally: diff --git a/containers/README_costs.md b/containers/README_costs.md index 7101b50..6ab4de2 100644 --- a/containers/README_costs.md +++ b/containers/README_costs.md @@ -5,7 +5,7 @@ ECR: - transfer out: $0.09 per GB transferred S3: -- storage: $0.023 per GB / mont +- storage: $0.023 per GB / month - requests: $0.005 (PUT, COPY, POST, LIST) or $0.0004 (GET, SELECT) - transfer out: 100GB free per month, after that $0.02 per GB (other aws services) or $0.09 per GB (outside of aws) diff --git a/containers/example_formats.py b/containers/example_formats.py new file mode 100644 index 0000000..6e1c70a --- /dev/null +++ b/containers/example_formats.py @@ -0,0 +1,154 @@ + + +source_data = { + "source": "dandi", # or "s3" + "source_data_type": "nwb", # or "spikeglx" + "source_data_paths": { + "file": "https://dandiarchive.org/dandiset/000003/0.210813.1807" + }, + "recording_kwargs": { + "electrical_series_name": "ElectricalSeries", + } +} + + +source_data_2 = { + "source": "s3", + "source_data_type": "spikeglx", + "source_data_paths": { + "file_bin": "s3://bucket/path/to/file.ap.bin", + "file_meta": "s3://bucket/path/to/file2.ap.meta", + }, + "recording_kwargs": {} +} + + +preprocessing_params = dict( + preprocessing_strategy="cmr", # 'destripe' or 'cmr' + highpass_filter=dict(freq_min=300.0, margin_ms=5.0), + phase_shift=dict(margin_ms=100.0), + detect_bad_channels=dict( + method="coherence+psd", + dead_channel_threshold=-0.5, + noisy_channel_threshold=1.0, + outside_channel_threshold=-0.3, + n_neighbors=11, + seed=0, + ), + remove_out_channels=False, + remove_bad_channels=False, + max_bad_channel_fraction_to_remove=1.1, + common_reference=dict(reference="global", operator="median"), + highpass_spatial_filter=dict( + n_channel_pad=60, + n_channel_taper=None, + direction="y", + apply_agc=True, + agc_window_length_s=0.01, + highpass_butter_order=3, + highpass_butter_wn=0.01, + ), +) + +qm_params = { + "presence_ratio": {"bin_duration_s": 60}, + "snr": {"peak_sign": "neg", "peak_mode": "extremum", "random_chunk_kwargs_dict": None}, + "isi_violation": {"isi_threshold_ms": 1.5, "min_isi_ms": 0}, + "rp_violation": {"refractory_period_ms": 1, "censored_period_ms": 0.0}, + "sliding_rp_violation": { + "bin_size_ms": 0.25, + "window_size_s": 1, + "exclude_ref_period_below_ms": 0.5, + "max_ref_period_ms": 10, + "contamination_values": None, + }, + "amplitude_cutoff": { + "peak_sign": "neg", + "num_histogram_bins": 100, + "histogram_smoothing_value": 3, + "amplitudes_bins_min_ratio": 5, + }, + "amplitude_median": {"peak_sign": "neg"}, + "nearest_neighbor": {"max_spikes": 10000, "min_spikes": 10, "n_neighbors": 4}, + "nn_isolation": {"max_spikes": 10000, "min_spikes": 10, "n_neighbors": 4, "n_components": 10, "radius_um": 100}, + "nn_noise_overlap": {"max_spikes": 10000, "min_spikes": 10, 
"n_neighbors": 4, "n_components": 10, "radius_um": 100}, +} +qm_metric_names = [ + "num_spikes", + "firing_rate", + "presence_ratio", + "snr", + "isi_violation", + "rp_violation", + "sliding_rp_violation", + "amplitude_cutoff", + "drift", + "isolation_distance", + "l_ratio", + "d_prime", +] + +postprocessing_params = dict( + sparsity=dict(method="radius", radius_um=100), + waveforms_deduplicate=dict( + ms_before=0.5, + ms_after=1.5, + max_spikes_per_unit=100, + return_scaled=False, + dtype=None, + precompute_template=("average",), + use_relative_path=True, + ), + waveforms=dict( + ms_before=3.0, + ms_after=4.0, + max_spikes_per_unit=500, + return_scaled=True, + dtype=None, + precompute_template=("average", "std"), + use_relative_path=True, + ), + spike_amplitudes=dict( + peak_sign="neg", + return_scaled=True, + outputs="concatenated", + ), + similarity=dict(method="cosine_similarity"), + correlograms=dict( + window_ms=100.0, + bin_ms=2.0, + ), + isis=dict( + window_ms=100.0, + bin_ms=5.0, + ), + locations=dict(method="monopolar_triangulation"), + template_metrics=dict(upsampling_factor=10, sparsity=None), + principal_components=dict(n_components=5, mode="by_channel_local", whiten=True), + quality_metrics=dict( + qm_params=qm_params, + metric_names=qm_metric_names, + n_jobs=1 + ), +) + +curation_params = dict( + duplicate_threshold=0.9, + isi_violations_ratio_threshold=0.5, + presence_ratio_threshold=0.8, + amplitude_cutoff_threshold=0.1, +) + +visualization_params = dict( + timeseries=dict(n_snippets_per_segment=2, snippet_duration_s=0.5, skip=False), + drift=dict( + detection=dict(method="locally_exclusive", peak_sign="neg", detect_threshold=5, exclude_sweep_ms=0.1), + localization=dict(ms_before=0.1, ms_after=0.3, local_radius_um=100.0), + n_skip=30, + alpha=0.15, + vmin=-200, + vmax=0, + cmap="Greys_r", + figsize=(10, 10), + ), +) diff --git a/containers/main.py b/containers/main.py index 0e108e7..79d00b5 100644 --- a/containers/main.py +++ b/containers/main.py @@ -1,4 +1,5 @@ import boto3 +import json import os import ast import subprocess @@ -33,6 +34,7 @@ from utils import ( make_logger, + validate_not_none, download_file_from_s3, upload_file_to_bucket, upload_all_files_to_bucket_folder, @@ -40,183 +42,31 @@ ) -### PARAMS: # TODO: probably we should pass a JSON file -n_jobs = os.cpu_count() -job_kwargs = dict(n_jobs=n_jobs, chunk_duration="1s", progress_bar=False) - -preprocessing_params = dict( - preprocessing_strategy="cmr", # 'destripe' or 'cmr' - highpass_filter=dict(freq_min=300.0, margin_ms=5.0), - phase_shift=dict(margin_ms=100.0), - detect_bad_channels=dict( - method="coherence+psd", - dead_channel_threshold=-0.5, - noisy_channel_threshold=1.0, - outside_channel_threshold=-0.3, - n_neighbors=11, - seed=0, - ), - remove_out_channels=False, - remove_bad_channels=False, - max_bad_channel_fraction_to_remove=1.1, - common_reference=dict(reference="global", operator="median"), - highpass_spatial_filter=dict( - n_channel_pad=60, - n_channel_taper=None, - direction="y", - apply_agc=True, - agc_window_length_s=0.01, - highpass_butter_order=3, - highpass_butter_wn=0.01, - ), -) - -qm_params = { - "presence_ratio": {"bin_duration_s": 60}, - "snr": {"peak_sign": "neg", "peak_mode": "extremum", "random_chunk_kwargs_dict": None}, - "isi_violation": {"isi_threshold_ms": 1.5, "min_isi_ms": 0}, - "rp_violation": {"refractory_period_ms": 1, "censored_period_ms": 0.0}, - "sliding_rp_violation": { - "bin_size_ms": 0.25, - "window_size_s": 1, - "exclude_ref_period_below_ms": 0.5, - 
"max_ref_period_ms": 10, - "contamination_values": None, - }, - "amplitude_cutoff": { - "peak_sign": "neg", - "num_histogram_bins": 100, - "histogram_smoothing_value": 3, - "amplitudes_bins_min_ratio": 5, - }, - "amplitude_median": {"peak_sign": "neg"}, - "nearest_neighbor": {"max_spikes": 10000, "min_spikes": 10, "n_neighbors": 4}, - "nn_isolation": {"max_spikes": 10000, "min_spikes": 10, "n_neighbors": 4, "n_components": 10, "radius_um": 100}, - "nn_noise_overlap": {"max_spikes": 10000, "min_spikes": 10, "n_neighbors": 4, "n_components": 10, "radius_um": 100}, -} -qm_metric_names = [ - "num_spikes", - "firing_rate", - "presence_ratio", - "snr", - "isi_violation", - "rp_violation", - "sliding_rp_violation", - "amplitude_cutoff", - "drift", - "isolation_distance", - "l_ratio", - "d_prime", -] - -sparsity_params = dict(method="radius", radius_um=100) - -postprocessing_params = dict( - sparsity=sparsity_params, - waveforms_deduplicate=dict( - ms_before=0.5, - ms_after=1.5, - max_spikes_per_unit=100, - return_scaled=False, - dtype=None, - precompute_template=("average",), - use_relative_path=True, - ), - waveforms=dict( - ms_before=3.0, - ms_after=4.0, - max_spikes_per_unit=500, - return_scaled=True, - dtype=None, - precompute_template=("average", "std"), - use_relative_path=True, - ), - spike_amplitudes=dict( - peak_sign="neg", - return_scaled=True, - outputs="concatenated", - ), - similarity=dict(method="cosine_similarity"), - correlograms=dict( - window_ms=100.0, - bin_ms=2.0, - ), - isis=dict( - window_ms=100.0, - bin_ms=5.0, - ), - locations=dict(method="monopolar_triangulation"), - template_metrics=dict(upsampling_factor=10, sparsity=None), - principal_components=dict(n_components=5, mode="by_channel_local", whiten=True), - quality_metrics=dict(qm_params=qm_params, metric_names=qm_metric_names, n_jobs=1), -) - -curation_params = dict( - duplicate_threshold=0.9, - isi_violations_ratio_threshold=0.5, - presence_ratio_threshold=0.8, - amplitude_cutoff_threshold=0.1, -) - -visualization_params = dict( - timeseries=dict(n_snippets_per_segment=2, snippet_duration_s=0.5, skip=False), - drift=dict( - detection=dict(method="locally_exclusive", peak_sign="neg", detect_threshold=5, exclude_sweep_ms=0.1), - localization=dict(ms_before=0.1, ms_after=0.3, local_radius_um=100.0), - n_skip=30, - alpha=0.15, - vmin=-200, - vmax=0, - cmap="Greys_r", - figsize=(10, 10), - ), -) - -data_folder = Path("data/") -scratch_folder = Path("scratch/") -results_folder = Path("results/") - - def main( - run_identifier: str = None, - source: str = None, - source_data_type: str = None, - source_data_paths: dict = None, - recording_kwargs: dict = None, - output_destination: str = None, - output_path: str = None, - sorter_name: str = None, - sorter_kwargs: dict = None, - concatenate_segments: bool = None, + run_at: str, + run_identifier: str, + run_description: str, test_with_toy_recording: bool = None, test_with_subrecording: bool = None, test_subrecording_n_frames: int = None, log_to_file: bool = None, + source_name: str = None, + source_data_type: str = None, + source_data_paths: dict = None, + recording_kwargs: dict = None, + preprocessing_kwargs: dict = None, + sorter_kwargs: dict = None, + postprocessing_kwargs: dict = None, + curation_kwargs: dict = None, + visualization_kwargs: dict = None, + output_destination: str = None, + output_path: str = None ): """ This script should run in an ephemeral Docker container and will: 1. download a dataset with raw electrophysiology traces from a specfied location - 2. 
run a SpikeInterface pipeline on the raw traces - 3. save the results in a target S3 bucket - - The arguments for this script can be passsed as ENV variables: - - RUN_IDENTIFIER : Unique identifier for this run. - - SOURCE : Source of input data. Choose from: local, s3, dandi. - - SOURCE_DATA_PATHS : Dictionary with paths to source data. Keys are names of data files, values are urls. - - SOURCE_DATA_TYPE : Data type to be read. Choose from: nwb, spikeglx. - - RECORDING_KWARGS : SpikeInterface extractor keyword arguments, specific to chosen dataset type. - - OUTPUT_DESTINATION : Destination for saving results. Choose from: local, s3, dandi. - - OUTPUT_PATH : Path for saving results. - If S3, should be a valid S3 path, E.g. s3://... - If local, should be a valid local path, E.g. /data/results - If dandi, should be a valid Dandiset uri, E.g. https://dandiarchive.org/dandiset/000001 - - SORTERS_NAME : Name of sorter to run on source data. - - SORTERS_KWARGS : Parameters for the sorter, stored as a dictionary. - - CONCATENATE_SEGMENTS : If True, concatenates all segments of the recording into one. - - TEST_WITH_TOY_RECORDING : Runs script with a toy dataset. - - TEST_WITH_SUB_RECORDING : Runs script with the first 4 seconds of target dataset. - - TEST_SUB_RECORDING_N_FRAMES : Number of frames to use for sub-recording. - - LOG_TO_FILE : If True, logs will be saved to a file in /logs folder. + 2. run a SpikeInterface pipeline, including preprocessing, spike sorting, postprocessing and curation + 3. save the results in a target S3 bucket or DANDI archive If running this in any AWS service (e.g. Batch, ECS, EC2...) the access to other AWS services such as S3 storage can be given to the container by an IAM role. @@ -228,70 +78,61 @@ def main( If saving results to DANDI archive, or reading from embargoed dandisets, the following ENV variables should be present in the running container: - DANDI_API_KEY - DANDI_API_KEY_STAGING - """ - - # Order of priority for definition of running arguments: - # 1. passed by function - # 2. retrieved from ENV vars - # 3. 
default value - if not run_identifier: - run_identifier = os.environ.get("RUN_IDENTIFIER", datetime.now().strftime("%Y%m%d%H%M%S")) - if not source: - source = os.environ.get("SOURCE", None) - if source == "None": - source = None - if not source_data_paths: - source_data_paths = eval(os.environ.get("SOURCE_DATA_PATHS", "{}")) - if not source_data_type: - source_data_type = os.environ.get("SOURCE_DATA_TYPE", "nwb") - if not recording_kwargs: - recording_kwargs = ast.literal_eval(os.environ.get("RECORDING_KWARGS", "{}")) - if not output_destination: - output_destination = os.environ.get("OUTPUT_DESTINATION", "s3") - if not output_path: - output_path = os.environ.get("OUTPUT_PATH", None) - if output_path == "None": - output_path = None - if not sorter_name: - sorter_name = os.environ.get("SORTER_NAME", "kilosort2.5") - if not sorter_kwargs: - sorter_kwargs = eval(os.environ.get("SORTER_KWARGS", "{}")) - if not concatenate_segments: - concatenate_segments_str = os.environ.get("CONCATENATE_SEGMENTS", "False") - concatenate_segments = True if concatenate_segments_str == "True" else False - if test_with_toy_recording is None: - test_with_toy_recording = os.environ.get("TEST_WITH_TOY_RECORDING", "False").lower() in ("true", "1", "t") - if test_with_subrecording is None: - test_with_subrecording = os.environ.get("TEST_WITH_SUB_RECORDING", "False").lower() in ("true", "1", "t") - if not test_subrecording_n_frames: - test_subrecording_n_frames = int(os.environ.get("TEST_SUBRECORDING_N_FRAMES", 300000)) - if log_to_file is None: - log_to_file = os.environ.get("LOG_TO_FILE", "False").lower() in ("true", "1", "t") + Parameters + ---------- + run_at : str + Where to run the sorting job. Choose from: aws, local. + run_identifier : str + Unique identifier for this run. + run_description : str + Description for this run. + test_with_toy_recording : bool + If True, runs script with a toy dataset. + test_with_subrecording : bool + If True, runs script with a subrecording of the source dataset. + test_subrecording_n_frames : int + Number of frames to use for sub-recording. + log_to_file : bool + If True, logs will be saved to a file in /logs folder. + source_name : str + Source of input data. Choose from: local, s3, dandi. + source_data_type : str + Data type to be read. Choose from: nwb, spikeglx. + source_data_paths : dict + Dictionary with paths to source data. Keys are names of data files, values are urls or paths. + recording_kwargs : dict + SpikeInterface recording keyword arguments, specific to chosen dataset type. + preprocessing_kwargs : dict + SpikeInterface preprocessing keyword arguments. + sorter_kwargs : dict + SpikeInterface sorter keyword arguments. + postprocessing_kwargs : dict + SpikeInterface postprocessing keyword arguments. + curation_kwargs : dict + SpikeInterface curation keyword arguments. + visualization_kwargs : dict + SpikeInterface visualization keyword arguments. + output_destination : str + Destination for saving results. Choose from: local, s3, dandi. + output_path : str + Path for saving results. + If S3, should be a valid S3 path, E.g. s3://... + If local, should be a valid local path, E.g. /data/results + If dandi, should be a valid Dandiset uri, E.g. 
https://dandiarchive.org/dandiset/000001 + """ # Set up logging logger = make_logger(run_identifier=run_identifier, log_to_file=log_to_file) filterwarnings(action="ignore", message="No cached namespaces found in .*") filterwarnings(action="ignore", message="Ignoring cached namespace .*") - # SET DEFAULT JOB KWARGS - si.set_global_job_kwargs(**job_kwargs) - - # Create folders - data_folder.mkdir(exist_ok=True) - scratch_folder.mkdir(exist_ok=True) - results_folder.mkdir(exist_ok=True) - tmp_folder = scratch_folder / "tmp" - if tmp_folder.is_dir(): - shutil.rmtree(tmp_folder) - tmp_folder.mkdir() - # Checks - if source not in ["local", "s3", "dandi"]: - logger.error(f"Source {source} not supported. Choose from: local, s3, dandi.") - raise ValueError(f"Source {source} not supported. Choose from: local, s3, dandi.") + if source_name not in ["local", "s3", "dandi"]: + logger.error(f"Source {source_name} not supported. Choose from: local, s3, dandi.") + raise ValueError(f"Source {source_name} not supported. Choose from: local, s3, dandi.") - # TODO: here we could leverage spikeinterface and add more options + # TODO: here we could eventually leverage spikeinterface and add more options if source_data_type not in ["nwb", "spikeglx"]: logger.error(f"Data type {source_data_type} not supported. Choose from: nwb, spikeglx.") raise ValueError(f"Data type {source_data_type} not supported. Choose from: nwb, spikeglx.") @@ -312,7 +153,28 @@ def main( output_s3_bucket = output_path_parsed.split("/")[0] output_s3_bucket_folder = "/".join(output_path_parsed.split("/")[1:]) - s3_client = boto3.client("s3") + # Set SpikeInterface global job kwargs + si.set_global_job_kwargs( + n_jobs=os.cpu_count(), + chunk_duration="1s", + progress_bar=False + ) + + # Create SpikeInterface folders + data_folder = Path("data/") + data_folder.mkdir(exist_ok=True) + scratch_folder = Path("scratch/") + scratch_folder.mkdir(exist_ok=True) + results_folder = Path("results/") + results_folder.mkdir(exist_ok=True) + tmp_folder = scratch_folder / "tmp" + if tmp_folder.is_dir(): + shutil.rmtree(tmp_folder) + tmp_folder.mkdir() + + # S3 client + if source_name == "s3" or output_destination == "s3": + s3_client = boto3.client("s3") # Test with toy recording if test_with_toy_recording: @@ -322,7 +184,7 @@ def main( recording_name = "toy" # Load data from S3 - elif source == "s3": + elif source_name == "s3": for k, data_url in source_data_paths.items(): if not data_url.startswith("s3://"): logger.error(f"Data url {data_url} is not a valid S3 path. E.g. s3://...") @@ -336,7 +198,6 @@ def main( bucket_name=bucket_name, file_path=file_path, ) - logger.info("Reading recording...") # E.g.: se.read_spikeglx(folder_path="/data", stream_id="imec.ap") if source_data_type == "spikeglx": @@ -345,17 +206,16 @@ def main( recording = se.read_nwb_recording(file_path=f"/data/{file_name}", **recording_kwargs) recording_name = "recording_on_s3" - elif source == "dandi": + # Load data from DANDI archive + elif source_name == "dandi": dandiset_s3_file_url = source_data_paths["file"] if not dandiset_s3_file_url.startswith("https://dandiarchive"): raise Exception( f"DANDISET_S3_FILE_URL should be a valid Dandiset S3 url. 
Value received was: {dandiset_s3_file_url}" ) - if not test_with_subrecording: logger.info(f"Downloading dataset: {dandiset_s3_file_url}") download_file_from_url(dandiset_s3_file_url) - logger.info("Reading recording from NWB...") recording = se.read_nwb_recording(file_path="/data/filename.nwb", **recording_kwargs) else: @@ -363,12 +223,18 @@ def main( recording = se.read_nwb_recording(file_path=dandiset_s3_file_url, stream_mode="fsspec", **recording_kwargs) recording_name = "recording_on_dandi" + # TODO - Load data from local files + elif source_name == "local": + pass + + # Run with subrecording if test_with_subrecording: n_frames = int(min(test_subrecording_n_frames, recording.get_num_frames())) recording = recording.frame_slice(start_frame=0, end_frame=n_frames) # ------------------------------------------------------------------------------------ # Preprocessing + # ------------------------------------------------------------------------------------ logger.info("Starting preprocessing...") preprocessing_notes = "" preprocessed_folder = tmp_folder / "preprocessed" @@ -376,13 +242,13 @@ def main( logger.info(f"\tDuration: {np.round(recording.get_total_duration(), 2)} s") if "inter_sample_shift" in recording.get_property_keys(): - recording_ps_full = spre.phase_shift(recording, **preprocessing_params["phase_shift"]) + recording_ps_full = spre.phase_shift(recording, **preprocessing_kwargs["phase_shift"]) else: recording_ps_full = recording - recording_hp_full = spre.highpass_filter(recording_ps_full, **preprocessing_params["highpass_filter"]) + recording_hp_full = spre.highpass_filter(recording_ps_full, **preprocessing_kwargs["highpass_filter"]) # IBL bad channel detection - _, channel_labels = spre.detect_bad_channels(recording_hp_full, **preprocessing_params["detect_bad_channels"]) + _, channel_labels = spre.detect_bad_channels(recording_hp_full, **preprocessing_kwargs["detect_bad_channels"]) dead_channel_mask = channel_labels == "dead" noise_channel_mask = channel_labels == "noise" out_channel_mask = channel_labels == "out" @@ -395,7 +261,7 @@ def main( out_channel_ids = recording_hp_full.channel_ids[out_channel_mask] all_bad_channel_ids = np.concatenate((dead_channel_ids, noise_channel_ids, out_channel_ids)) - max_bad_channel_fraction_to_remove = preprocessing_params["max_bad_channel_fraction_to_remove"] + max_bad_channel_fraction_to_remove = preprocessing_kwargs["max_bad_channel_fraction_to_remove"] if len(all_bad_channel_ids) >= int(max_bad_channel_fraction_to_remove * recording.get_num_channels()): logger.info( f"\tMore than {max_bad_channel_fraction_to_remove * 100}% bad channels ({len(all_bad_channel_ids)}). " @@ -405,28 +271,28 @@ def main( # in this case, we don't bother sorting return else: - if preprocessing_params["remove_out_channels"]: + if preprocessing_kwargs["remove_out_channels"]: logger.info(f"\tRemoving {len(out_channel_ids)} out channels") recording_rm_out = recording_hp_full.remove_channels(out_channel_ids) preprocessing_notes += f"\n- Removed {len(out_channel_ids)} outside of the brain." 
else: recording_rm_out = recording_hp_full - recording_processed_cmr = spre.common_reference(recording_rm_out, **preprocessing_params["common_reference"]) + recording_processed_cmr = spre.common_reference(recording_rm_out, **preprocessing_kwargs["common_reference"]) bad_channel_ids = np.concatenate((dead_channel_ids, noise_channel_ids)) recording_interp = spre.interpolate_bad_channels(recording_rm_out, bad_channel_ids) recording_hp_spatial = spre.highpass_spatial_filter( - recording_interp, **preprocessing_params["highpass_spatial_filter"] + recording_interp, **preprocessing_kwargs["highpass_spatial_filter"] ) - preproc_strategy = preprocessing_params["preprocessing_strategy"] + preproc_strategy = preprocessing_kwargs["preprocessing_strategy"] if preproc_strategy == "cmr": recording_processed = recording_processed_cmr else: recording_processed = recording_hp_spatial - if preprocessing_params["remove_bad_channels"]: + if preprocessing_kwargs["remove_bad_channels"]: logger.info(f"\tRemoving {len(bad_channel_ids)} channels after {preproc_strategy} preprocessing") recording_processed = recording_processed.remove_channels(bad_channel_ids) preprocessing_notes += f"\n- Removed {len(bad_channel_ids)} bad channels after preprocessing.\n" @@ -438,6 +304,8 @@ def main( # ------------------------------------------------------------------------------------ # Spike Sorting + # ------------------------------------------------------------------------------------ + sorter_name = sorter_kwargs["sorter_name"] logger.info(f"\n\nStarting spike sorting with {sorter_name}") spikesorting_notes = "" sorting_params = None @@ -451,7 +319,7 @@ def main( if recording_processed.get_num_segments() > 1: recording_processed = si.concatenate_recordings([recording_processed]) - # run ks2.5 + # Run sorter try: sorting = ss.run_sorter( sorter_name, @@ -472,6 +340,7 @@ def main( # remove empty units sorting = sorting.remove_empty_units() + # remove spikes beyond num_samples (if any) sorting = sc.remove_excess_spikes(sorting=sorting, recording=recording_processed) logger.info(f"\tSorting output without empty units: {sorting}") @@ -491,6 +360,7 @@ def main( # ------------------------------------------------------------------------------------ # Postprocessing + # ------------------------------------------------------------------------------------ logger.info("\n\Starting postprocessing...") postprocessing_notes = "" t_postprocessing_start = time.perf_counter() @@ -498,10 +368,10 @@ def main( # first extract some raw waveforms in memory to deduplicate based on peak alignment wf_dedup_folder = tmp_folder / "postprocessed" / recording_name we_raw = si.extract_waveforms( - recording_processed, sorting, folder=wf_dedup_folder, **postprocessing_params["waveforms_deduplicate"] + recording_processed, sorting, folder=wf_dedup_folder, **postprocessing_kwargs["waveforms_deduplicate"] ) # de-duplication - sorting_deduplicated = sc.remove_redundant_units(we_raw, duplicate_threshold=curation_params["duplicate_threshold"]) + sorting_deduplicated = sc.remove_redundant_units(we_raw, duplicate_threshold=curation_kwargs["duplicate_threshold"]) logger.info( f"\tNumber of original units: {len(we_raw.sorting.unit_ids)} -- Number of units after de-duplication: {len(sorting_deduplicated.unit_ids)}" ) @@ -510,7 +380,7 @@ def main( ) deduplicated_unit_ids = sorting_deduplicated.unit_ids # use existing deduplicated waveforms to compute sparsity - sparsity_raw = si.compute_sparsity(we_raw, **sparsity_params) + sparsity_raw = si.compute_sparsity(we_raw, 
**postprocessing_kwargs["sparsity"]) sparsity_mask = sparsity_raw.mask[sorting.ids_to_indices(deduplicated_unit_ids), :] sparsity = si.ChannelSparsity(mask=sparsity_mask, unit_ids=deduplicated_unit_ids, channel_ids=recording.channel_ids) shutil.rmtree(wf_dedup_folder) @@ -527,40 +397,42 @@ def main( sparsity=sparsity, sparse=True, overwrite=True, - **postprocessing_params["waveforms"], + **postprocessing_kwargs["waveforms"], ) logger.info("\tComputing spike amplitides") - spike_amplitudes = spost.compute_spike_amplitudes(we, **postprocessing_params["spike_amplitudes"]) + spike_amplitudes = spost.compute_spike_amplitudes(we, **postprocessing_kwargs["spike_amplitudes"]) logger.info("\tComputing unit locations") - unit_locations = spost.compute_unit_locations(we, **postprocessing_params["locations"]) + unit_locations = spost.compute_unit_locations(we, **postprocessing_kwargs["locations"]) logger.info("\tComputing spike locations") - spike_locations = spost.compute_spike_locations(we, **postprocessing_params["locations"]) + spike_locations = spost.compute_spike_locations(we, **postprocessing_kwargs["locations"]) logger.info("\tComputing correlograms") - ccg, bins = spost.compute_correlograms(we, **postprocessing_params["correlograms"]) + ccg, bins = spost.compute_correlograms(we, **postprocessing_kwargs["correlograms"]) logger.info("\tComputing ISI histograms") - isi, bins = spost.compute_isi_histograms(we, **postprocessing_params["isis"]) + isi, bins = spost.compute_isi_histograms(we, **postprocessing_kwargs["isis"]) logger.info("\tComputing template similarity") - sim = spost.compute_template_similarity(we, **postprocessing_params["similarity"]) + sim = spost.compute_template_similarity(we, **postprocessing_kwargs["similarity"]) logger.info("\tComputing template metrics") - tm = spost.compute_template_metrics(we, **postprocessing_params["template_metrics"]) + tm = spost.compute_template_metrics(we, **postprocessing_kwargs["template_metrics"]) logger.info("\tComputing PCA") - pca = spost.compute_principal_components(we, **postprocessing_params["principal_components"]) + pca = spost.compute_principal_components(we, **postprocessing_kwargs["principal_components"]) logger.info("\tComputing quality metrics") - qm = sqm.compute_quality_metrics(we, **postprocessing_params["quality_metrics"]) + qm = sqm.compute_quality_metrics(we, **postprocessing_kwargs["quality_metrics"]) t_postprocessing_end = time.perf_counter() elapsed_time_postprocessing = np.round(t_postprocessing_end - t_postprocessing_start, 2) logger.info(f"Postprocessing time: {elapsed_time_postprocessing}s") - ###### CURATION ############## + # ------------------------------------------------------------------------------------ + # Curation + # ------------------------------------------------------------------------------------ logger.info("\n\Starting curation...") curation_notes = "" t_curation_start = time.perf_counter() # curation query - isi_violations_ratio_thr = curation_params["isi_violations_ratio_threshold"] - presence_ratio_thr = curation_params["presence_ratio_threshold"] - amplitude_cutoff_thr = curation_params["amplitude_cutoff_threshold"] + isi_violations_ratio_thr = curation_kwargs["isi_violations_ratio_threshold"] + presence_ratio_thr = curation_kwargs["presence_ratio_threshold"] + amplitude_cutoff_thr = curation_kwargs["amplitude_cutoff_threshold"] curation_query = f"isi_violations_ratio < {isi_violations_ratio_thr} and presence_ratio > {presence_ratio_thr} and amplitude_cutoff < {amplitude_cutoff_thr}" 
logger.info(f"Curation query: {curation_query}") @@ -590,10 +462,15 @@ def main( elapsed_time_curation = np.round(t_curation_end - t_curation_start, 2) logger.info(f"Curation time: {elapsed_time_curation}s") + # ------------------------------------------------------------------------------------ # TODO: Visualization with FIGURL (needs credentials) + # ------------------------------------------------------------------------------------ + + # ------------------------------------------------------------------------------------ # Conversion and upload + # ------------------------------------------------------------------------------------ logger.info("Writing sorting results to NWB...") metadata = { "NWBFile": { @@ -612,7 +489,7 @@ def main( results_nwb_folder.mkdir(parents=True, exist_ok=True) output_nwbfile_path = results_nwb_folder / f"{run_identifier}.nwb" - # TODO: Condider writing waveforms instead of sorting + # TODO: Consider writing waveforms instead of sorting # add sorting properties # unit locations sorting.set_property("unit_locations", unit_locations) @@ -706,7 +583,88 @@ def main( if __name__ == "__main__": - main() + # Get run kwargs from ENV variables + run_kwargs = json.loads(os.environ.get("SI_RUN_KWARGS", "{}")) + run_at = validate_not_none(run_kwargs, "run_at") + run_identifier = run_kwargs.get("run_identifier", datetime.now().strftime("%Y%m%d%H%M%S")) + run_description = run_kwargs.get("run_description", "") + test_with_toy_recording = run_kwargs.get("test_with_toy_recording", "False") + test_with_subrecording = run_kwargs.get("test_with_subrecording", "False") + test_subrecording_n_frames = int(run_kwargs.get("test_subrecording_n_frames", 30000)) + log_to_file = run_kwargs.get("log_to_file", "False") + + # Get source data kwargs from ENV variables + source_data_kwargs = json.loads(os.environ.get("SI_SOURCE_DATA_KWARGS", "{}")) + source_name = validate_not_none(source_data_kwargs, "source_name") + source_data_type = validate_not_none(source_data_kwargs, "source_data_type") + source_data_paths = validate_not_none(source_data_kwargs, "source_data_paths") + + # Get recording kwargs from ENV variables + recording_kwargs = json.loads(os.environ.get("SI_RECORDING_KWARGS", "{}")) + + # Get preprocessing kwargs from ENV variables + preprocessing_kwargs = json.loads(os.environ.get("SI_PREPROCESSING_KWARGS", "{}")) + + # Get sorter kwargs from ENV variables + sorter_kwargs = json.loads(os.environ.get("SI_SORTER_KWARGS", "{}")) + + # Get postprocessing kwargs from ENV variables + postprocessing_kwargs = json.loads(os.environ.get("SI_POSTPROCESSING_KWARGS", "{}")) + + # Get curation kwargs from ENV variables + curation_kwargs = json.loads(os.environ.get("SI_CURATION_KWARGS", "{}")) + + # Get visualization kwargs from ENV variables + visualization_kwargs = json.loads(os.environ.get("SI_VISUALIZATION_KWARGS", "{}")) + + # Get output kwargs from ENV variables + output_kwargs = json.loads(os.environ.get("SI_OUTPUT_DATA_KWARGS", "{}")) + output_destination = validate_not_none(output_kwargs, "output_destination") + output_path = validate_not_none(output_kwargs, "output_path") + + # Just for checking for now - REMOVE LATER + print("\nRun at: ", run_at) + print("\nRun identifier: ", run_identifier) + print("\nRun description: ", run_description) + print("\nTest with toy recording: ", test_with_toy_recording) + print("\nTest with subrecording: ", test_with_subrecording) + print("\nTest subrecording n frames: ", test_subrecording_n_frames) + print("\nLog to file: ", log_to_file) + 
print("\nSource name: ", source_name) + print("\nSource data type: ", source_data_type) + print("\nSource data paths: ", source_data_paths) + print("\nRecording kwargs: ", recording_kwargs) + print("\nPreprocessing kwargs: ", preprocessing_kwargs) + print("\nSorter kwargs: ", sorter_kwargs) + print("\nPostprocessing kwargs: ", postprocessing_kwargs) + print("\nCuration kwargs: ", curation_kwargs) + print("\nVisualization kwargs: ", visualization_kwargs) + print("\nOutput destination: ", output_destination) + print("\nOutput path: ", output_path) + + # Run main function + main( + run_at=run_at, + run_identifier=run_identifier, + run_description=run_description, + test_with_toy_recording=test_with_toy_recording, + test_with_subrecording=test_with_subrecording, + test_subrecording_n_frames=test_subrecording_n_frames, + log_to_file=log_to_file, + source_name=source_name, + source_data_type=source_data_type, + source_data_paths=source_data_paths, + recording_kwargs=recording_kwargs, + preprocessing_kwargs=preprocessing_kwargs, + sorter_kwargs=sorter_kwargs, + postprocessing_kwargs=postprocessing_kwargs, + curation_kwargs=curation_kwargs, + visualization_kwargs=visualization_kwargs, + output_destination=output_destination, + output_path=output_path, + ) + + # Known issues: # diff --git a/containers/source_format.py b/containers/source_format.py deleted file mode 100644 index 216717f..0000000 --- a/containers/source_format.py +++ /dev/null @@ -1,23 +0,0 @@ - - -source_data = { - "source": "dandi", # or "s3" - "source_data_type": "nwb", # or "spikeglx" - "source_data_paths": { - "file": "https://dandiarchive.org/dandiset/000003/0.210813.1807" - }, - "recording_kwargs": { - "electrical_series_name": "ElectricalSeries", - } -} - - -source_data_2 = { - "source": "s3", - "source_data_type": "spikeglx", - "source_data_paths": { - "file_bin": "s3://bucket/path/to/file.ap.bin", - "file_meta": "s3://bucket/path/to/file2.ap.meta", - }, - "recording_kwargs": {} -} \ No newline at end of file diff --git a/containers/utils.py b/containers/utils.py index 9598648..60de1e0 100644 --- a/containers/utils.py +++ b/containers/utils.py @@ -56,6 +56,13 @@ def make_logger(run_identifier: str, log_to_file: bool): return logger +def validate_not_none(d, k): + v = d.get(k, None) + if v is None: + raise ValueError(f"{k} not specified.") + return v + + def download_file_from_url(url): # ref: https://stackoverflow.com/a/39217788/11483674 local_filename = "/data/filename.nwb" diff --git a/docker-compose-dev.yaml b/docker-compose-dev.yaml new file mode 100644 index 0000000..16e2fea --- /dev/null +++ b/docker-compose-dev.yaml @@ -0,0 +1,93 @@ +version: "3" + +services: + docker-proxy: + image: bobrik/socat + container_name: si-docker-proxy + command: "TCP4-LISTEN:2375,fork,reuseaddr UNIX-CONNECT:/var/run/docker.sock" + ports: + - "2376:2375" + volumes: + - /var/run/docker.sock:/var/run/docker.sock + + # frontend: + # build: + # context: frontend + # dockerfile: Dockerfile + # image: si-sorting-frontend + # container_name: si-sorting-frontend + # command: ["npm", "run", "start"] + # ports: + # - "5173:5173" + # environment: + # DEPLOY_MODE: compose + # volumes: + # - ./frontend:/app + # depends_on: + # - rest + + rest: + build: + context: rest + dockerfile: Dockerfile + image: si-sorting-rest + container_name: si-sorting-rest + command: ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--reload", "--reload-dir", "/app"] + ports: + - "8000:8000" + environment: + REST_DEPLOY_MODE: compose + WORKER_DEPLOY_MODE: 
compose + AWS_DEFAULT_REGION: ${AWS_DEFAULT_REGION} + AWS_ACCESS_KEY_ID: ${AWS_ACCESS_KEY_ID} + AWS_SECRET_ACCESS_KEY: ${AWS_SECRET_ACCESS_KEY} + AWS_BATCH_JOB_QUEUE: ${AWS_BATCH_JOB_QUEUE} + AWS_BATCH_JOB_DEFINITION: ${AWS_BATCH_JOB_DEFINITION} + DANDI_API_KEY: ${DANDI_API_KEY} + volumes: + - ./rest:/app + depends_on: + - database + + # worker: + # build: + # context: containers + # dockerfile: Dockerfile.combined + # image: si-sorting-worker + # # image: ghcr.io/catalystneuro/si-sorting-worker:latest + # container_name: si-sorting-worker + # ports: + # - "5000:5000" + # environment: + # WORKER_DEPLOY_MODE: compose + # AWS_DEFAULT_REGION: ${AWS_DEFAULT_REGION} + # AWS_ACCESS_KEY_ID: ${AWS_ACCESS_KEY_ID} + # AWS_SECRET_ACCESS_KEY: ${AWS_SECRET_ACCESS_KEY} + # DANDI_API_KEY: ${DANDI_API_KEY} + # DANDI_API_KEY_STAGING: ${DANDI_API_KEY_STAGING} + # volumes: + # - ./containers:/app + # - ./results:/results + # - ./logs:/logs + # deploy: + # resources: + # reservations: + # devices: + # - driver: nvidia + # count: 1 + # capabilities: [gpu] + + database: + image: postgres:latest + container_name: si-sorting-db + environment: + POSTGRES_USER: postgres + POSTGRES_PASSWORD: postgres + POSTGRES_DB: si-sorting-db + volumes: + - pgdata:/var/lib/postgresql/data + ports: + - "5432:5432" + +volumes: + pgdata: diff --git a/docker-compose-dev.yml b/docker-compose-dev.yml deleted file mode 100644 index 4df075c..0000000 --- a/docker-compose-dev.yml +++ /dev/null @@ -1,83 +0,0 @@ -version: "3" - -services: - frontend: - build: - context: frontend - dockerfile: Dockerfile - image: si-sorting-frontend - container_name: si-sorting-frontend - command: ["npm", "run", "start"] - ports: - - "5173:5173" - environment: - DEPLOY_MODE: compose - volumes: - - ./frontend:/app - depends_on: - - rest - - rest: - build: - context: rest - dockerfile: Dockerfile - image: si-sorting-rest - container_name: si-sorting-rest - ports: - - "8000:8000" - environment: - REST_DEPLOY_MODE: compose - WORKER_DEPLOY_MODE: compose - AWS_DEFAULT_REGION: ${AWS_DEFAULT_REGION} - AWS_ACCESS_KEY_ID: ${AWS_ACCESS_KEY_ID} - AWS_SECRET_ACCESS_KEY: ${AWS_SECRET_ACCESS_KEY} - AWS_BATCH_JOB_QUEUE: ${AWS_BATCH_JOB_QUEUE} - AWS_BATCH_JOB_DEFINITION: ${AWS_BATCH_JOB_DEFINITION} - DANDI_API_KEY: ${DANDI_API_KEY} - volumes: - - ./rest:/app - depends_on: - - database - - worker: - build: - context: containers - dockerfile: Dockerfile.combined - image: si-sorting-worker - # image: ghcr.io/catalystneuro/si-sorting-worker:latest - container_name: si-sorting-worker - ports: - - "5000:5000" - environment: - WORKER_DEPLOY_MODE: compose - AWS_DEFAULT_REGION: ${AWS_DEFAULT_REGION} - AWS_ACCESS_KEY_ID: ${AWS_ACCESS_KEY_ID} - AWS_SECRET_ACCESS_KEY: ${AWS_SECRET_ACCESS_KEY} - DANDI_API_KEY: ${DANDI_API_KEY} - DANDI_API_KEY_STAGING: ${DANDI_API_KEY_STAGING} - volumes: - - ./containers:/app - - ./results:/results - - ./logs:/logs - deploy: - resources: - reservations: - devices: - - driver: nvidia - count: 1 - capabilities: [gpu] - - database: - image: postgres:latest - container_name: si-sorting-db - environment: - POSTGRES_USER: postgres - POSTGRES_PASSWORD: postgres - POSTGRES_DB: si-sorting-db - volumes: - - pgdata:/var/lib/postgresql/data - ports: - - "5432:5432" - -volumes: - pgdata: diff --git a/docker-compose.yml b/docker-compose.yaml similarity index 100% rename from docker-compose.yml rename to docker-compose.yaml diff --git a/rest/Dockerfile b/rest/Dockerfile index df71593..48d746c 100644 --- a/rest/Dockerfile +++ b/rest/Dockerfile @@ -19,4 +19,7 @@ ENV 
SI_CLOUD_ENV production ENV PYTHONUNBUFFERED=1 EXPOSE 8000 -CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"] \ No newline at end of file + +WORKDIR / + +CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] \ No newline at end of file diff --git a/rest/README.md b/rest/README.md new file mode 100644 index 0000000..3889cdf --- /dev/null +++ b/rest/README.md @@ -0,0 +1,7 @@ +# REST API + +To run REST API in local environment, from the root directory of the project run: + +```bash +uvicorn rest.main:app --host 0.0.0.0 --port 8000 --workers 4 --reload +``` \ No newline at end of file diff --git a/rest/__init__.py b/rest/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/rest/clients/aws.py b/rest/clients/aws.py index 04f89c1..d347e2a 100644 --- a/rest/clients/aws.py +++ b/rest/clients/aws.py @@ -1,7 +1,7 @@ import boto3 import enum -from core import settings +from ..core import settings class JobStatus(enum.Enum): diff --git a/rest/clients/database.py b/rest/clients/database.py index 48689e1..a552525 100644 --- a/rest/clients/database.py +++ b/rest/clients/database.py @@ -4,7 +4,7 @@ import ast import json -from db.models import User, DataSource, Run +from ..db.models import User, DataSource, Run class DatabaseClient: diff --git a/rest/clients/local_docker.py b/rest/clients/local_docker.py new file mode 100644 index 0000000..f2d97fa --- /dev/null +++ b/rest/clients/local_docker.py @@ -0,0 +1,96 @@ +from pathlib import Path +import docker +import uuid + +from ..core.logger import logger +from ..models.sorting import ( + RunKwargs, + SourceDataKwargs, + RecordingKwargs, + PreprocessingKwargs, + SorterKwargs, + PostprocessingKwargs, + CurationKwargs, + VisualizationKwargs, + OutputDataKwargs +) + + +map_sorter_to_image = { + "kilosort2": "ghcr.io/catalystneuro/si-sorting-ks2:latest", + "kilosort25": "ghcr.io/catalystneuro/si-sorting-ks25:latest", + "kilosort3": "ghcr.io/catalystneuro/si-sorting-ks3:latest", + "ironclust": "ghcr.io/catalystneuro/si-sorting-ironclust:latest", + "spykingcircus": "ghcr.io/catalystneuro/si-sorting-spyking-circus:latest", +} + + +class LocalDockerClient: + + def __init__(self, base_url: str = "tcp://docker-proxy:2375"): + self.logger = logger + self.client = docker.DockerClient(base_url=base_url) + + def run_sorting( + self, + run_kwargs: RunKwargs, + source_data_kwargs: SourceDataKwargs, + recording_kwargs: RecordingKwargs, + preprocessing_kwargs: PreprocessingKwargs, + sorter_kwargs: SorterKwargs, + postprocessing_kwargs: PostprocessingKwargs, + curation_kwargs: CurationKwargs, + visualization_kwargs: VisualizationKwargs, + output_data_kwargs: OutputDataKwargs, + ) -> None: + # Pass kwargs as environment variables to the container + env_vars = dict( + SI_RUN_KWARGS=run_kwargs.json(), + SI_SOURCE_DATA_KWARGS=source_data_kwargs.json(), + SI_RECORDING_KWARGS=recording_kwargs.json(), + SI_PREPROCESSING_KWARGS=preprocessing_kwargs.json(), + SI_SORTER_KWARGS=sorter_kwargs.json(), + SI_POSTPROCESSING_KWARGS=postprocessing_kwargs.json(), + SI_CURATION_KWARGS=curation_kwargs.json(), + SI_VISUALIZATION_KWARGS=visualization_kwargs.json(), + SI_OUTPUT_DATA_KWARGS=output_data_kwargs.json(), + ) + + # Local volumes to mount + local_directory = Path(".").absolute() + logs_directory = local_directory / "logs" + results_directory = local_directory / "results" + volumes = { + logs_directory: {'bind': '/logs', 'mode': 'rw'}, + results_directory: {'bind': '/results', 'mode': 'rw'}, + } + + container = 
self.client.containers.run( + name=f'si-sorting-run-{run_kwargs.run_identifier}-{uuid.uuid4().hex[:6]}', + image=map_sorter_to_image[sorter_kwargs.sorter_name], + command=['python', 'main.py'], + detach=True, + environment=env_vars, + volumes=volumes, + device_requests=[ + docker.types.DeviceRequest( + device_ids=["0"], + capabilities=[['gpu']] + ) + ] + ) + + def get_run_logs(self, run_identifier): + # TODO: Implement this + self.logger.info("Getting logs...") + # response = requests.get(self.url + "/logs", params={"run_identifier": run_identifier}) + # if response.status_code == 200: + # logs = response.content.decode('utf-8') + # if "Error running sorter" in logs: + # return "fail", logs + # elif "Sorting job completed successfully!" in logs: + # return "success", logs + # return "running", logs + # else: + # self.logger.info(f"Error {response.status_code}: {response.content}") + # return "fail", f"Logs couldn't be retrieved. Error {response.status_code}: {response.content}" \ No newline at end of file diff --git a/rest/clients/local_worker.py b/rest/clients/local_worker_deprecated.py similarity index 100% rename from rest/clients/local_worker.py rename to rest/clients/local_worker_deprecated.py diff --git a/rest/core/__init__.py b/rest/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/rest/data/__init__.py b/rest/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/rest/db/utils.py b/rest/db/utils.py index 4e7666c..097df5b 100644 --- a/rest/db/utils.py +++ b/rest/db/utils.py @@ -2,7 +2,7 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import relationship, sessionmaker -from db.models import Base, User, DataSource, Run +from .models import Base, User, DataSource, Run def initialize_db(db: str): diff --git a/rest/main.py b/rest/main.py index 4ed5098..c2959a1 100644 --- a/rest/main.py +++ b/rest/main.py @@ -5,13 +5,13 @@ from fastapi.responses import JSONResponse from pathlib import Path -from core.settings import settings -from routes.user import router as router_user -from routes.dandi import router as router_dandi -from routes.sorting import router as router_sorting -from routes.runs import router as router_runs -from clients.dandi import DandiClient -from db.utils import initialize_db +from .core.settings import settings +from .routes.user import router as router_user +from .routes.dandi import router as router_dandi +from .routes.sorting import router as router_sorting +from .routes.runs import router as router_runs +from .clients.dandi import DandiClient +from .db.utils import initialize_db import logging diff --git a/rest/models/sorting.py b/rest/models/sorting.py index c43709e..c85549d 100644 --- a/rest/models/sorting.py +++ b/rest/models/sorting.py @@ -1,43 +1,284 @@ -from pydantic import BaseModel -from typing import List +from pydantic import BaseModel, Field, Extra +from typing import Optional, Dict, List, Union, Tuple from enum import Enum +from datetime import datetime -class OutputDestination(str, Enum): - s3 = "s3" - dandi = "dandi" +# ------------------------------ +# Run Models +# ------------------------------ +class RunAt(str, Enum): + aws = "aws" local = "local" +def default_run_identifier(): + return datetime.now().strftime("%Y%m%d-%H%M%S") + +class RunKwargs(BaseModel): + run_at: RunAt = Field(..., description="Where to run the sorting job. 
Choose from: aws, local.") + run_identifier: str = Field(default_factory=default_run_identifier, description="Unique identifier for the run.") + run_description: str = Field(default="", description="Description of the run.") + test_with_toy_recording: bool = Field(default=False, description="Whether to test with a toy recording.") + test_with_subrecording: bool = Field(default=False, description="Whether to test with a subrecording.") + test_subrecording_n_frames: Optional[int] = Field(default=30000, description="Number of frames to use for the subrecording.") + log_to_file: bool = Field(default=False, description="Whether to log to a file.") + + +# ------------------------------ +# Source Data Models +# ------------------------------ +class SourceName(str, Enum): + local = "local" + s3 = "s3" + dandi = "dandi" class SourceDataType(str, Enum): nwb = "nwb" spikeglx = "spikeglx" +class SourceDataKwargs(BaseModel): + source_name: SourceName = Field(..., description="Source of input data. Choose from: s3, dandi, local.") + source_data_type: SourceDataType = Field(..., description="Type of input data. Choose from: nwb, spikeglx.") + source_data_paths: Dict[str, str] = Field(..., description="Dictionary with paths to source data. Keys are names of data files, values are urls.") -class Source(str, Enum): + +# ------------------------------ +# Output Data Models +# ------------------------------ +class OutputDestination(str, Enum): + local = "local" s3 = "s3" dandi = "dandi" +class OutputDataKwargs(BaseModel): + output_destination: OutputDestination = Field(..., description="Destination of output data. Choose from: s3, dandi, local.") + output_path: str = Field(..., description="Path to output data.") -class RunAt(str, Enum): - aws = "aws" - local = "local" +# ------------------------------ +# Recording Models +# ------------------------------ +class RecordingKwargs(BaseModel, extra=Extra.allow): + pass + + +# ------------------------------ +# Preprocessing Models +# ------------------------------ +class HighpassFilter(BaseModel): + freq_min: float = Field(default=300.0, description="Minimum frequency for the highpass filter") + margin_ms: float = Field(default=5.0, description="Margin in milliseconds") + +class PhaseShift(BaseModel): + margin_ms: float = Field(default=100.0, description="Margin in milliseconds for phase shift") + +class DetectBadChannels(BaseModel): + method: str = Field(default="coherence+psd", description="Method to detect bad channels") + dead_channel_threshold: float = Field(default=-0.5, description="Threshold for dead channel") + noisy_channel_threshold: float = Field(default=1.0, description="Threshold for noisy channel") + outside_channel_threshold: float = Field(default=-0.3, description="Threshold for outside channel") + n_neighbors: int = Field(default=11, description="Number of neighbors") + seed: int = Field(default=0, description="Seed value") + +class CommonReference(BaseModel): + reference: str = Field(default="global", description="Type of reference") + operator: str = Field(default="median", description="Operator used for common reference") + +class HighpassSpatialFilter(BaseModel): + n_channel_pad: int = Field(default=60, description="Number of channels to pad") + n_channel_taper: Optional[int] = Field(default=None, description="Number of channels to taper") + direction: str = Field(default="y", description="Direction for the spatial filter") + apply_agc: bool = Field(default=True, description="Whether to apply automatic gain control") + agc_window_length_s: 
float = Field(default=0.01, description="Window length in seconds for AGC") + highpass_butter_order: int = Field(default=3, description="Order for the Butterworth filter") + highpass_butter_wn: float = Field(default=0.01, description="Natural frequency for the Butterworth filter") + +class PreprocessingKwargs(BaseModel): + preprocessing_strategy: str = Field(default="cmr", description="Strategy for preprocessing") + highpass_filter: HighpassFilter + phase_shift: PhaseShift + detect_bad_channels: DetectBadChannels + remove_out_channels: bool = Field(default=False, description="Flag to remove out channels") + remove_bad_channels: bool = Field(default=False, description="Flag to remove bad channels") + max_bad_channel_fraction_to_remove: float = Field(default=1.1, description="Maximum fraction of bad channels to remove") + common_reference: CommonReference + highpass_spatial_filter: HighpassSpatialFilter + + +# ------------------------------ +# Sorter Models +# ------------------------------ +class SorterName(str, Enum): + ironclust = "ironclust" + kilosort2 = "kilosort2" + kilosort25 = "kilosort25" + kilosort3 = "kilosort3" + spykingcircus = "spykingcircus" + +class SorterKwargs(BaseModel, extra=Extra.allow): + sorter_name: SorterName = Field(..., description="Name of the sorter to use.") + + +# ------------------------------ +# Postprocessing Models +# ------------------------------ +class PresenceRatio(BaseModel): + bin_duration_s: float = Field(60, description="Duration of the bin in seconds.") + +class SNR(BaseModel): + peak_sign: str = Field("neg", description="Sign of the peak.") + peak_mode: str = Field("extremum", description="Mode of the peak.") + random_chunk_kwargs_dict: Optional[dict] = Field(None, description="Random chunk arguments.") + +class ISIViolation(BaseModel): + isi_threshold_ms: float = Field(1.5, description="ISI threshold in milliseconds.") + min_isi_ms: float = Field(0., description="Minimum ISI in milliseconds.") + +class RPViolation(BaseModel): + refractory_period_ms: float = Field(1., description="Refractory period in milliseconds.") + censored_period_ms: float = Field(0.0, description="Censored period in milliseconds.") + +class SlidingRPViolation(BaseModel): + min_spikes: int = Field(0, description="Contamination is set to np.nan if the unit has less than this many spikes across all segments.") + bin_size_ms: float = Field(0.25, description="The size of binning for the autocorrelogram in ms, by default 0.25.") + window_size_s: float = Field(1, description="Window in seconds to compute correlogram, by default 1.") + exclude_ref_period_below_ms: float = Field(0.5, description="Refractory periods below this value are excluded, by default 0.5") + max_ref_period_ms: float = Field(10, description="Maximum refractory period to test in ms, by default 10 ms.") + contamination_values: Optional[list] = Field(None, description="The contamination values to test, by default np.arange(0.5, 35, 0.5) %") + +class PeakSign(str, Enum): + neg = "neg" + pos = "pos" + both = "both" + +class AmplitudeCutoff(BaseModel): + peak_sign: PeakSign = Field("neg", description="The sign of the peaks.") + num_histogram_bins: int = Field(100, description="The number of bins to use to compute the amplitude histogram.") + histogram_smoothing_value: int = Field(3, description="Controls the smoothing applied to the amplitude histogram.") + amplitudes_bins_min_ratio: int = Field(5, description="The minimum ratio between number of amplitudes for a unit and the number of bins. 
If the ratio is less than this threshold, the amplitude_cutoff for the unit is set to NaN.") + +class AmplitudeMedian(BaseModel): + peak_sign: PeakSign = Field("neg", description="The sign of the peaks.") + +class NearestNeighbor(BaseModel): + max_spikes: int = Field(10000, description="The number of spikes to use, per cluster. Note that the calculation can be very slow when this number is >20000.") + min_spikes: int = Field(10, description="Minimum number of spikes.") + n_neighbors: int = Field(4, description="The number of neighbors to use.") + +class NNIsolation(NearestNeighbor): + n_components: int = Field(10, description="The number of PC components to use to project the snippets to.") + radius_um: int = Field(100, description="The radius, in um, that channels need to be within the peak channel to be included.") + +class QMParams(BaseModel): + presence_ratio: PresenceRatio + snr: SNR + isi_violation: ISIViolation + rp_violation: RPViolation + sliding_rp_violation: SlidingRPViolation + amplitude_cutoff: AmplitudeCutoff + amplitude_median: AmplitudeMedian + nearest_neighbor: NearestNeighbor + nn_isolation: NNIsolation + nn_noise_overlap: NNIsolation + +class QualityMetrics(BaseModel): + qm_params: QMParams = Field(..., description="Quality metric parameters.") + metric_names: List[str] = Field(..., description="List of metric names to compute.") + n_jobs: int = Field(1, description="Number of jobs.") + +class Sparsity(BaseModel): + method: str = Field("radius", description="Method for determining sparsity.") + radius_um: int = Field(100, description="Radius in micrometers for sparsity.") + +class Waveforms(BaseModel): + ms_before: float = Field(3.0, description="Milliseconds before") + ms_after: float = Field(4.0, description="Milliseconds after") + max_spikes_per_unit: int = Field(500, description="Maximum spikes per unit") + return_scaled: bool = Field(True, description="Flag to determine if results should be scaled") + dtype: Optional[str] = Field(None, description="Data type for the waveforms") + precompute_template: Tuple[str, str] = Field(("average", "std"), description="Precomputation template method") + use_relative_path: bool = Field(True, description="Use relative paths") + +class SpikeAmplitudes(BaseModel): + peak_sign: str = Field("neg", description="Sign of the peak") + return_scaled: bool = Field(True, description="Flag to determine if amplitudes should be scaled") + outputs: str = Field("concatenated", description="Output format for the spike amplitudes") + +class Similarity(BaseModel): + method: str = Field("cosine_similarity", description="Method to compute similarity") + +class Correlograms(BaseModel): + window_ms: float = Field(100.0, description="Size of the window in milliseconds") + bin_ms: float = Field(2.0, description="Size of the bin in milliseconds") + +class ISIS(BaseModel): + window_ms: float = Field(100.0, description="Size of the window in milliseconds") + bin_ms: float = Field(5.0, description="Size of the bin in milliseconds") + +class Locations(BaseModel): + method: str = Field("monopolar_triangulation", description="Method to determine locations") + +class TemplateMetrics(BaseModel): + upsampling_factor: int = Field(10, description="Upsampling factor") + sparsity: Optional[str] = Field(None, description="Sparsity method") + +class PrincipalComponents(BaseModel): + n_components: int = Field(5, description="Number of principal components") + mode: str = Field("by_channel_local", description="Mode of principal component analysis") + whiten: bool = 
Field(True, description="Whiten the components") + +class PostprocessingKwargs(BaseModel): + sparsity: Sparsity + waveforms_deduplicate: Waveforms + waveforms: Waveforms + spike_amplitudes: SpikeAmplitudes + similarity: Similarity + correlograms: Correlograms + isis: ISIS + locations: Locations + template_metrics: TemplateMetrics + principal_components: PrincipalComponents + quality_metrics: QualityMetrics + +# ------------------------------ +# Curation Models +# ------------------------------ +class CurationKwargs(BaseModel): + duplicate_threshold: float = Field(0.9, description="Threshold for duplicate units") + isi_violations_ratio_threshold: float = Field(0.5, description="Threshold for ISI violations ratio") + presence_ratio_threshold: float = Field(0.8, description="Threshold for presence ratio") + amplitude_cutoff_threshold: float = Field(0.1, description="Threshold for amplitude cutoff") + + +# ------------------------------ +# Visualization Models +# ------------------------------ +class Timeseries(BaseModel): + n_snippets_per_segment: int = Field(2, description="Number of snippets per segment") + snippet_duration_s: float = Field(0.5, description="Duration of the snippet in seconds") + skip: bool = Field(False, description="Flag to skip") + +class Detection(BaseModel): + method: str = Field("locally_exclusive", description="Method for detection") + peak_sign: str = Field("neg", description="Sign of the peak") + detect_threshold: int = Field(5, description="Detection threshold") + exclude_sweep_ms: float = Field(0.1, description="Exclude sweep in milliseconds") + +class Localization(BaseModel): + ms_before: float = Field(0.1, description="Milliseconds before") + ms_after: float = Field(0.3, description="Milliseconds after") + local_radius_um: float = Field(100.0, description="Local radius in micrometers") + +class Drift(BaseModel): + detection: Detection + localization: Localization + n_skip: int = Field(30, description="Number of skips") + alpha: float = Field(0.15, description="Alpha value") + vmin: int = Field(-200, description="Minimum value") + vmax: int = Field(0, description="Maximum value") + cmap: str = Field("Greys_r", description="Colormap") + figsize: Tuple[int, int] = Field((10, 10), description="Figure size") +class VisualizationKwargs(BaseModel): + timeseries: Timeseries + drift: Drift -class SortingData(BaseModel): - run_at: RunAt = "local" - run_identifier: str = None - run_description: str = None - source: Source = None - source_data_type: SourceDataType = None - source_data_paths: dict = None - subject_metadata: dict = None - recording_kwargs: dict = None - output_destination: OutputDestination = None - output_path: str = None - sorters_names_list: List[str] = None - sorters_kwargs: dict = None - test_with_toy_recording: bool = None - test_with_subrecording: bool = None - test_subrecording_n_frames: int = None - log_to_file: bool = None \ No newline at end of file diff --git a/rest/requirements.txt b/rest/requirements.txt index fb8d014..8fff453 100644 --- a/rest/requirements.txt +++ b/rest/requirements.txt @@ -1,8 +1,11 @@ -fastapi[all]==0.95.0 -dandi==0.56.0 +fastapi==0.103.2 +uvicorn==0.23.2 +pydantic==1.10.13 +dandi==0.56.2 fsspec==2023.3.0 requests==2.28.2 aiohttp==3.8.4 boto3==1.26.102 SQLAlchemy==2.0.8 -psycopg2==2.9.5 \ No newline at end of file +psycopg2==2.9.5 +docker==6.1.3 \ No newline at end of file diff --git a/rest/routes/__init__.py b/rest/routes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/rest/routes/dandi.py 
b/rest/routes/dandi.py index 52395f7..035b06b 100644 --- a/rest/routes/dandi.py +++ b/rest/routes/dandi.py @@ -2,8 +2,8 @@ from fastapi.responses import JSONResponse from typing import List -from clients.dandi import DandiClient -from core.settings import settings +from ..clients.dandi import DandiClient +from ..core.settings import settings router = APIRouter() diff --git a/rest/routes/runs.py b/rest/routes/runs.py index a64eb61..e22e294 100644 --- a/rest/routes/runs.py +++ b/rest/routes/runs.py @@ -2,11 +2,11 @@ from fastapi.responses import JSONResponse from pydantic import BaseModel -from core.logger import logger -from core.settings import settings -from clients.database import DatabaseClient -from clients.aws import AWSClient -from clients.local_worker import LocalWorkerClient +from ..core.logger import logger +from ..core.settings import settings +from ..clients.database import DatabaseClient +from ..clients.aws import AWSClient +from ..clients.local_docker import LocalDockerClient router = APIRouter() @@ -37,7 +37,8 @@ def get_run_info(run_id: str): if "Error running sorter" in run_logs: status = "fail" elif run_info["run_at"] == "local": - local_worker_client = LocalWorkerClient() + # TODO: Implement this + local_worker_client = LocalDockerClient() status, run_logs = local_worker_client.get_run_logs(run_identifier=run_info['identifier']) else: status = "running" diff --git a/rest/routes/sorting.py b/rest/routes/sorting.py index cc97a7d..30f8dac 100644 --- a/rest/routes/sorting.py +++ b/rest/routes/sorting.py @@ -2,28 +2,59 @@ from fastapi.responses import JSONResponse from datetime import datetime -from core.logger import logger -from core.settings import settings -from clients.dandi import DandiClient -from clients.aws import AWSClient -from clients.local_worker import LocalWorkerClient -from clients.database import DatabaseClient -from models.sorting import SortingData +from ..core.logger import logger +from ..core.settings import settings +from ..clients.dandi import DandiClient +from ..clients.aws import AWSClient +from ..clients.local_docker import LocalDockerClient +from ..clients.database import DatabaseClient +from ..models.sorting import ( + RunKwargs, + SourceDataKwargs, + RecordingKwargs, + PreprocessingKwargs, + SorterKwargs, + PostprocessingKwargs, + CurationKwargs, + VisualizationKwargs, + OutputDataKwargs +) router = APIRouter() -def sorting_background_task(payload, run_identifier): +def sorting_background_task( + run_kwargs: RunKwargs, + source_data_kwargs: SourceDataKwargs, + recording_kwargs: RecordingKwargs, + preprocessing_kwargs: PreprocessingKwargs, + sorter_kwargs: SorterKwargs, + postprocessing_kwargs: PostprocessingKwargs, + curation_kwargs: CurationKwargs, + visualization_kwargs: VisualizationKwargs, + output_data_kwargs: OutputDataKwargs, +): # Run sorting and update db entry status db_client = DatabaseClient(connection_string=settings.DB_CONNECTION_STRING) - run_at = payload.get("run_at", None) + run_at = run_kwargs.run_at try: logger.info(f"Run job at: {run_at}") if run_at == "local": - client_local_worker = LocalWorkerClient() - client_local_worker.run_sorting(**payload) + client_local_worker = LocalDockerClient() + client_local_worker.run_sorting( + run_kwargs=run_kwargs, + source_data_kwargs=source_data_kwargs, + recording_kwargs=recording_kwargs, + preprocessing_kwargs=preprocessing_kwargs, + sorter_kwargs=sorter_kwargs, + postprocessing_kwargs=postprocessing_kwargs, + curation_kwargs=curation_kwargs, + visualization_kwargs=visualization_kwargs, + 
output_data_kwargs=output_data_kwargs, + ) elif run_at == "aws": + # TODO: Implement this job_kwargs = {k.upper(): v for k, v in payload.items()} job_kwargs["DANDI_API_KEY"] = settings.DANDI_API_KEY client_aws = AWSClient() @@ -33,49 +64,64 @@ def sorting_background_task(payload, run_identifier): job_definition=settings.AWS_BATCH_JOB_DEFINITION, job_kwargs=job_kwargs, ) - db_client.update_run(run_identifier=run_identifier, key="status", value="running") + # db_client.update_run(run_identifier=run_identifier, key="status", value="running") except Exception as e: - logger.exception(f"Error running sorting job: {run_identifier}.\n {e}") - db_client.update_run(run_identifier=run_identifier, key="status", value="fail") + logger.exception(f"Error running sorting job: {run_kwargs.run_identifier}.\n {e}") + db_client.update_run(run_identifier=run_kwargs.run_identifier, key="status", value="fail") @router.post("/run", response_description="Run Sorting", tags=["sorting"]) -async def route_run_sorting(data: SortingData, background_tasks: BackgroundTasks) -> JSONResponse: - if not data.run_identifier: - run_identifier = datetime.now().strftime("%Y%m%d%H%M%S") - else: - run_identifier = data.run_identifier +async def route_run_sorting( + run_kwargs: RunKwargs, + source_data_kwargs: SourceDataKwargs, + recording_kwargs: RecordingKwargs, + preprocessing_kwargs: PreprocessingKwargs, + sorter_kwargs: SorterKwargs, + postprocessing_kwargs: PostprocessingKwargs, + curation_kwargs: CurationKwargs, + visualization_kwargs: VisualizationKwargs, + output_data_kwargs: OutputDataKwargs, + background_tasks: BackgroundTasks +) -> JSONResponse: try: # Create Database entries db_client = DatabaseClient(connection_string=settings.DB_CONNECTION_STRING) user = db_client.get_user_info(username="admin") - data_source = db_client.create_data_source( - name=data.run_identifier, - description=data.run_description, - user_id=user.id, - source=data.source, - source_data_type=data.source_data_type, - source_data_paths=str(data.source_data_paths), - recording_kwargs=str(data.recording_kwargs), - ) - run = db_client.create_run( - run_at=data.run_at, - identifier=run_identifier, - description=data.run_description, - last_run=datetime.now().strftime("%Y/%m/%d %H:%M:%S"), - status="running", - data_source_id=data_source.id, - user_id=user.id, - metadata=str(data.json()), - output_destination=data.output_destination, - output_path=data.output_path, - ) + # TODO: Create data source and run entries in database + # data_source = db_client.create_data_source( + # name=run_kwargs.run_identifier, + # description=run_kwargs.run_description, + # user_id=user.id, + # source=source_data_kwargs.source_name, + # source_data_type=source_data_kwargs.source_data_type, + # source_data_paths=str(source_data_kwargs.source_data_paths), + # recording_kwargs=str(recording_kwargs.dict()), + # ) + # run = db_client.create_run( + # run_at=data.run_at, + # identifier=run_identifier, + # description=data.run_description, + # last_run=datetime.now().strftime("%Y/%m/%d %H:%M:%S"), + # status="running", + # data_source_id=data_source.id, + # user_id=user.id, + # metadata=str(data.json()), + # output_destination=data.output_destination, + # output_path=data.output_path, + # ) # Run sorting job background_tasks.add_task( sorting_background_task, - payload=data.dict(), - run_identifier=run_identifier + run_kwargs=run_kwargs, + source_data_kwargs=source_data_kwargs, + recording_kwargs=recording_kwargs, + preprocessing_kwargs=preprocessing_kwargs, + 
sorter_kwargs=sorter_kwargs, + postprocessing_kwargs=postprocessing_kwargs, + curation_kwargs=curation_kwargs, + visualization_kwargs=visualization_kwargs, + output_data_kwargs=output_data_kwargs, ) except Exception as e: @@ -83,5 +129,5 @@ async def route_run_sorting(data: SortingData, background_tasks: BackgroundTasks raise HTTPException(status_code=500, detail="Internal server error") return JSONResponse(content={ "message": "Sorting job submitted", - "run_identifier": run.identifier, + "run_identifier": run_kwargs.run_identifier, }) \ No newline at end of file diff --git a/rest/routes/user.py b/rest/routes/user.py index 316a0c7..22e46ca 100644 --- a/rest/routes/user.py +++ b/rest/routes/user.py @@ -2,9 +2,9 @@ from fastapi.responses import JSONResponse from pydantic import BaseModel -from core.logger import logger -from core.settings import settings -from clients.database import DatabaseClient +from ..core.logger import logger +from ..core.settings import settings +from ..clients.database import DatabaseClient router = APIRouter()
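Usage note (sketch): the refactored `POST /sorting/run` route now declares nine separate Pydantic body models (`run_kwargs` through `output_data_kwargs`) instead of the single flat `SortingData` payload, so FastAPI will expect the request JSON to nest each object under its parameter name, e.g. `{"run_kwargs": {...}, "source_data_kwargs": {...}, ...}`. The snippet below is a minimal, illustrative sketch of what the new typed validation buys, assuming pydantic 1.10.x as pinned in `rest/requirements.txt`; the model definitions are trimmed copies of the classes added in `rest/models/sorting.py` (the `Field(...)` descriptions are dropped here), and the file URL is a placeholder, not a real asset.

```python
# Minimal validation sketch (pydantic 1.10.x, as pinned in rest/requirements.txt).
# Trimmed copies of the new typed models; the real definitions live in
# rest/models/sorting.py and carry Field(...) descriptions.
from enum import Enum
from typing import Dict

from pydantic import BaseModel, ValidationError


class SourceName(str, Enum):
    local = "local"
    s3 = "s3"
    dandi = "dandi"


class SourceDataType(str, Enum):
    nwb = "nwb"
    spikeglx = "spikeglx"


class SourceDataKwargs(BaseModel):
    source_name: SourceName
    source_data_type: SourceDataType
    source_data_paths: Dict[str, str]


# Valid payload: enum values are checked up front, unlike the old free-form
# SortingData dicts, which accepted any string.
ok = SourceDataKwargs(
    source_name="dandi",
    source_data_type="nwb",
    source_data_paths={"file": "https://example.org/path/to/file.nwb"},  # placeholder URL
)
print(ok.json())  # {"source_name": "dandi", "source_data_type": "nwb", ...}

# An unsupported source is now rejected at the API boundary instead of
# failing later inside a worker container.
try:
    SourceDataKwargs(source_name="ftp", source_data_type="nwb", source_data_paths={})
except ValidationError as err:
    print(err)
```

The same pattern applies to the other kwargs models: a client composes one JSON object keyed by the nine parameter names, each value shaped by its model, and invalid or missing sections are reported by FastAPI as a 422 before `sorting_background_task` is ever scheduled.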