From c31501eb69a58a3651a61740925627617dec5dd3 Mon Sep 17 00:00:00 2001 From: luiz Date: Wed, 11 Oct 2023 15:28:03 +0200 Subject: [PATCH 1/9] change worker to docker proxy --- containers/main.py | 10 +- docker-compose-dev.yml | 90 +++--- rest/Dockerfile | 5 +- rest/README.md | 7 + rest/__init__.py | 0 rest/clients/aws.py | 2 +- rest/clients/database.py | 2 +- rest/clients/local_docker.py | 87 ++++++ ...l_worker.py => local_worker_deprecated.py} | 0 rest/core/__init__.py | 0 rest/data/__init__.py | 0 rest/db/utils.py | 2 +- rest/main.py | 14 +- rest/models/sorting.py | 288 ++++++++++++++++-- rest/requirements.txt | 9 +- rest/routes/__init__.py | 0 rest/routes/dandi.py | 4 +- rest/routes/runs.py | 13 +- rest/routes/sorting.py | 127 +++++--- rest/routes/user.py | 6 +- 20 files changed, 531 insertions(+), 135 deletions(-) create mode 100644 rest/README.md create mode 100644 rest/__init__.py create mode 100644 rest/clients/local_docker.py rename rest/clients/{local_worker.py => local_worker_deprecated.py} (100%) create mode 100644 rest/core/__init__.py create mode 100644 rest/data/__init__.py create mode 100644 rest/routes/__init__.py diff --git a/containers/main.py b/containers/main.py index 0e108e7..13349cc 100644 --- a/containers/main.py +++ b/containers/main.py @@ -109,10 +109,8 @@ "d_prime", ] -sparsity_params = dict(method="radius", radius_um=100) - postprocessing_params = dict( - sparsity=sparsity_params, + sparsity=dict(method="radius", radius_um=100), waveforms_deduplicate=dict( ms_before=0.5, ms_after=1.5, @@ -148,7 +146,11 @@ locations=dict(method="monopolar_triangulation"), template_metrics=dict(upsampling_factor=10, sparsity=None), principal_components=dict(n_components=5, mode="by_channel_local", whiten=True), - quality_metrics=dict(qm_params=qm_params, metric_names=qm_metric_names, n_jobs=1), + quality_metrics=dict( + qm_params=qm_params, + metric_names=qm_metric_names, + n_jobs=1 + ), ) curation_params = dict( diff --git a/docker-compose-dev.yml b/docker-compose-dev.yml index 4df075c..16e2fea 100644 --- a/docker-compose-dev.yml +++ b/docker-compose-dev.yml @@ -1,21 +1,30 @@ version: "3" services: - frontend: - build: - context: frontend - dockerfile: Dockerfile - image: si-sorting-frontend - container_name: si-sorting-frontend - command: ["npm", "run", "start"] + docker-proxy: + image: bobrik/socat + container_name: si-docker-proxy + command: "TCP4-LISTEN:2375,fork,reuseaddr UNIX-CONNECT:/var/run/docker.sock" ports: - - "5173:5173" - environment: - DEPLOY_MODE: compose + - "2376:2375" volumes: - - ./frontend:/app - depends_on: - - rest + - /var/run/docker.sock:/var/run/docker.sock + + # frontend: + # build: + # context: frontend + # dockerfile: Dockerfile + # image: si-sorting-frontend + # container_name: si-sorting-frontend + # command: ["npm", "run", "start"] + # ports: + # - "5173:5173" + # environment: + # DEPLOY_MODE: compose + # volumes: + # - ./frontend:/app + # depends_on: + # - rest rest: build: @@ -23,6 +32,7 @@ services: dockerfile: Dockerfile image: si-sorting-rest container_name: si-sorting-rest + command: ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--reload", "--reload-dir", "/app"] ports: - "8000:8000" environment: @@ -39,33 +49,33 @@ services: depends_on: - database - worker: - build: - context: containers - dockerfile: Dockerfile.combined - image: si-sorting-worker - # image: ghcr.io/catalystneuro/si-sorting-worker:latest - container_name: si-sorting-worker - ports: - - "5000:5000" - environment: - WORKER_DEPLOY_MODE: compose - 
AWS_DEFAULT_REGION: ${AWS_DEFAULT_REGION} - AWS_ACCESS_KEY_ID: ${AWS_ACCESS_KEY_ID} - AWS_SECRET_ACCESS_KEY: ${AWS_SECRET_ACCESS_KEY} - DANDI_API_KEY: ${DANDI_API_KEY} - DANDI_API_KEY_STAGING: ${DANDI_API_KEY_STAGING} - volumes: - - ./containers:/app - - ./results:/results - - ./logs:/logs - deploy: - resources: - reservations: - devices: - - driver: nvidia - count: 1 - capabilities: [gpu] + # worker: + # build: + # context: containers + # dockerfile: Dockerfile.combined + # image: si-sorting-worker + # # image: ghcr.io/catalystneuro/si-sorting-worker:latest + # container_name: si-sorting-worker + # ports: + # - "5000:5000" + # environment: + # WORKER_DEPLOY_MODE: compose + # AWS_DEFAULT_REGION: ${AWS_DEFAULT_REGION} + # AWS_ACCESS_KEY_ID: ${AWS_ACCESS_KEY_ID} + # AWS_SECRET_ACCESS_KEY: ${AWS_SECRET_ACCESS_KEY} + # DANDI_API_KEY: ${DANDI_API_KEY} + # DANDI_API_KEY_STAGING: ${DANDI_API_KEY_STAGING} + # volumes: + # - ./containers:/app + # - ./results:/results + # - ./logs:/logs + # deploy: + # resources: + # reservations: + # devices: + # - driver: nvidia + # count: 1 + # capabilities: [gpu] database: image: postgres:latest diff --git a/rest/Dockerfile b/rest/Dockerfile index df71593..48d746c 100644 --- a/rest/Dockerfile +++ b/rest/Dockerfile @@ -19,4 +19,7 @@ ENV SI_CLOUD_ENV production ENV PYTHONUNBUFFERED=1 EXPOSE 8000 -CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000", "--reload"] \ No newline at end of file + +WORKDIR / + +CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"] \ No newline at end of file diff --git a/rest/README.md b/rest/README.md new file mode 100644 index 0000000..3889cdf --- /dev/null +++ b/rest/README.md @@ -0,0 +1,7 @@ +# REST API + +To run REST API in local environment, from the root directory of the project run: + +```bash +uvicorn rest.main:app --host 0.0.0.0 --port 8000 --workers 4 --reload +``` \ No newline at end of file diff --git a/rest/__init__.py b/rest/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/rest/clients/aws.py b/rest/clients/aws.py index 04f89c1..d347e2a 100644 --- a/rest/clients/aws.py +++ b/rest/clients/aws.py @@ -1,7 +1,7 @@ import boto3 import enum -from core import settings +from ..core import settings class JobStatus(enum.Enum): diff --git a/rest/clients/database.py b/rest/clients/database.py index 48689e1..a552525 100644 --- a/rest/clients/database.py +++ b/rest/clients/database.py @@ -4,7 +4,7 @@ import ast import json -from db.models import User, DataSource, Run +from ..db.models import User, DataSource, Run class DatabaseClient: diff --git a/rest/clients/local_docker.py b/rest/clients/local_docker.py new file mode 100644 index 0000000..587330e --- /dev/null +++ b/rest/clients/local_docker.py @@ -0,0 +1,87 @@ +from pathlib import Path +import docker + +from ..core.logger import logger +from ..models.sorting import ( + RunKwargs, + SourceDataKwargs, + RecordingKwargs, + PreprocessingKwargs, + SorterKwargs, + PostprocessingKwargs, + CurationKwargs, + VisualizationKwargs, +) + + +class LocalDockerClient: + + def __init__(self, base_url: str = "tcp://docker-proxy:2375"): + self.logger = logger + self.client = docker.DockerClient(base_url=base_url) + + def run_sorting( + self, + run_kwargs: RunKwargs, + source_data_kwargs: SourceDataKwargs, + recording_kwargs: RecordingKwargs, + preprocessing_kwargs: PreprocessingKwargs, + sorter_kwargs: SorterKwargs, + postprocessing_kwargs: PostprocessingKwargs, + curation_kwargs: CurationKwargs, + visualization_kwargs: VisualizationKwargs, + ) -> 
None: + # Pass kwargs as environment variables to the container + env_vars = dict( + SI_RUN_KWARGS=run_kwargs.json(), + SI_SOURCE_DATA_KWARGS=source_data_kwargs.json(), + SI_RECORDING_KWARGS=recording_kwargs.json(), + SI_PREPROCESSING_KWARGS=preprocessing_kwargs.json(), + SI_SORTER_KWARGS=sorter_kwargs.json(), + SI_POSTPROCESSING_KWARGS=postprocessing_kwargs.json(), + SI_CURATION_KWARGS=curation_kwargs.json(), + SI_VISUALIZATION_KWARGS=visualization_kwargs.json(), + ) + + # Local volumes to mount + local_directory = Path(".").absolute() + logs_directory = local_directory / "logs" + results_directory = local_directory / "results" + volumes = { + logs_directory: {'bind': '/logs', 'mode': 'rw'}, + results_directory: {'bind': '/results', 'mode': 'rw'}, + } + + container = self.client.containers.run( + image='python:slim', + command=['python', '-c', 'import os; print(os.environ.get("SI_RUN_KWARGS"))'], + detach=True, + environment=env_vars, + volumes=volumes, + device_requests=[ + docker.types.DeviceRequest( + device_ids=["0"], + capabilities=[['gpu']] + ) + ] + ) + # if response.status_code == 200: + # self.logger.info("Success!") + # else: + # self.logger.info(f"Error {response.status_code}: {response.content}") + + + def get_run_logs(self, run_identifier): + # TODO: Implement this + self.logger.info("Getting logs...") + # response = requests.get(self.url + "/logs", params={"run_identifier": run_identifier}) + # if response.status_code == 200: + # logs = response.content.decode('utf-8') + # if "Error running sorter" in logs: + # return "fail", logs + # elif "Sorting job completed successfully!" in logs: + # return "success", logs + # return "running", logs + # else: + # self.logger.info(f"Error {response.status_code}: {response.content}") + # return "fail", f"Logs couldn't be retrieved. 
Error {response.status_code}: {response.content}" \ No newline at end of file diff --git a/rest/clients/local_worker.py b/rest/clients/local_worker_deprecated.py similarity index 100% rename from rest/clients/local_worker.py rename to rest/clients/local_worker_deprecated.py diff --git a/rest/core/__init__.py b/rest/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/rest/data/__init__.py b/rest/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/rest/db/utils.py b/rest/db/utils.py index 4e7666c..097df5b 100644 --- a/rest/db/utils.py +++ b/rest/db/utils.py @@ -2,7 +2,7 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import relationship, sessionmaker -from db.models import Base, User, DataSource, Run +from .models import Base, User, DataSource, Run def initialize_db(db: str): diff --git a/rest/main.py b/rest/main.py index 4ed5098..c2959a1 100644 --- a/rest/main.py +++ b/rest/main.py @@ -5,13 +5,13 @@ from fastapi.responses import JSONResponse from pathlib import Path -from core.settings import settings -from routes.user import router as router_user -from routes.dandi import router as router_dandi -from routes.sorting import router as router_sorting -from routes.runs import router as router_runs -from clients.dandi import DandiClient -from db.utils import initialize_db +from .core.settings import settings +from .routes.user import router as router_user +from .routes.dandi import router as router_dandi +from .routes.sorting import router as router_sorting +from .routes.runs import router as router_runs +from .clients.dandi import DandiClient +from .db.utils import initialize_db import logging diff --git a/rest/models/sorting.py b/rest/models/sorting.py index c43709e..bed097b 100644 --- a/rest/models/sorting.py +++ b/rest/models/sorting.py @@ -1,43 +1,281 @@ -from pydantic import BaseModel -from typing import List +from pydantic import BaseModel, Field, Extra +from typing import Optional, Dict, List, Union, Tuple from enum import Enum -class OutputDestination(str, Enum): +# ------------------------------ +# Run Models +# ------------------------------ +class RunAt(str, Enum): + aws = "aws" + local = "local" + +class RunKwargs(BaseModel): + run_at: RunAt = Field(..., description="Where to run the sorting job. Choose from: aws, local.") + run_identifier: str = Field(..., description="Unique identifier for the run.") + run_description: str = Field(..., description="Description of the run.") + test_with_toy_recording: bool = Field(default=False, description="Whether to test with a toy recording.") + test_with_subrecording: bool = Field(default=False, description="Whether to test with a subrecording.") + test_subrecording_n_frames: Optional[int] = Field(default=30000, description="Number of frames to use for the subrecording.") + log_to_file: bool = Field(default=False, description="Whether to log to a file.") + + +# ------------------------------ +# Source Data Models +# ------------------------------ +class SourceName(str, Enum): s3 = "s3" dandi = "dandi" local = "local" - class SourceDataType(str, Enum): nwb = "nwb" spikeglx = "spikeglx" +class SourceDataKwargs(BaseModel): + source_name: SourceName = Field(..., description="Source of input data. Choose from: s3, dandi, local.") + source_data_paths: Dict[str, str] = Field(..., description="Dictionary with paths to source data. Keys are names of data files, values are urls.") + source_data_type: SourceDataType = Field(..., description="Type of input data. 
Choose from: nwb, spikeglx.") + -class Source(str, Enum): +# ------------------------------ +# Output Data Models +# ------------------------------ +class OutputDestination(str, Enum): s3 = "s3" dandi = "dandi" + local = "local" +class OutputDataKwargs(BaseModel): + output_destination: OutputDestination = Field(..., description="Destination of output data. Choose from: s3, dandi, local.") + output_path: str = Field(..., description="Path to output data.") -class RunAt(str, Enum): - aws = "aws" - local = "local" +# ------------------------------ +# Recording Models +# ------------------------------ +class RecordingKwargs(BaseModel, extra=Extra.allow): + pass + + +# ------------------------------ +# Preprocessing Models +# ------------------------------ +class HighpassFilter(BaseModel): + freq_min: float = Field(default=300.0, description="Minimum frequency for the highpass filter") + margin_ms: float = Field(default=5.0, description="Margin in milliseconds") + +class PhaseShift(BaseModel): + margin_ms: float = Field(default=100.0, description="Margin in milliseconds for phase shift") + +class DetectBadChannels(BaseModel): + method: str = Field(default="coherence+psd", description="Method to detect bad channels") + dead_channel_threshold: float = Field(default=-0.5, description="Threshold for dead channel") + noisy_channel_threshold: float = Field(default=1.0, description="Threshold for noisy channel") + outside_channel_threshold: float = Field(default=-0.3, description="Threshold for outside channel") + n_neighbors: int = Field(default=11, description="Number of neighbors") + seed: int = Field(default=0, description="Seed value") + +class CommonReference(BaseModel): + reference: str = Field(default="global", description="Type of reference") + operator: str = Field(default="median", description="Operator used for common reference") + +class HighpassSpatialFilter(BaseModel): + n_channel_pad: int = Field(default=60, description="Number of channels to pad") + n_channel_taper: Optional[int] = Field(default=None, description="Number of channels to taper") + direction: str = Field(default="y", description="Direction for the spatial filter") + apply_agc: bool = Field(default=True, description="Whether to apply automatic gain control") + agc_window_length_s: float = Field(default=0.01, description="Window length in seconds for AGC") + highpass_butter_order: int = Field(default=3, description="Order for the Butterworth filter") + highpass_butter_wn: float = Field(default=0.01, description="Natural frequency for the Butterworth filter") + +class PreprocessingKwargs(BaseModel): + preprocessing_strategy: str = Field(default="cmr", description="Strategy for preprocessing") + highpass_filter: HighpassFilter + phase_shift: PhaseShift + detect_bad_channels: DetectBadChannels + remove_out_channels: bool = Field(default=False, description="Flag to remove out channels") + remove_bad_channels: bool = Field(default=False, description="Flag to remove bad channels") + max_bad_channel_fraction_to_remove: float = Field(default=1.1, description="Maximum fraction of bad channels to remove") + common_reference: CommonReference + highpass_spatial_filter: HighpassSpatialFilter + + +# ------------------------------ +# Sorter Models +# ------------------------------ +class SorterName(str, Enum): + ironclust = "ironclust" + kilosort2 = "kilosort2" + kilosort25 = "kilosort25" + kilosort3 = "kilosort3" + spykingcircus = "spykingcircus" + +class SorterKwargs(BaseModel, extra=Extra.allow): + sorter_name: SorterName = Field(..., 
description="Name of the sorter to use.") + + +# ------------------------------ +# Postprocessing Models +# ------------------------------ +class PresenceRatio(BaseModel): + bin_duration_s: float = Field(60, description="Duration of the bin in seconds.") + +class SNR(BaseModel): + peak_sign: str = Field("neg", description="Sign of the peak.") + peak_mode: str = Field("extremum", description="Mode of the peak.") + random_chunk_kwargs_dict: Optional[dict] = Field(None, description="Random chunk arguments.") + +class ISIViolation(BaseModel): + isi_threshold_ms: float = Field(1.5, description="ISI threshold in milliseconds.") + min_isi_ms: float = Field(0., description="Minimum ISI in milliseconds.") + +class RPViolation(BaseModel): + refractory_period_ms: float = Field(1., description="Refractory period in milliseconds.") + censored_period_ms: float = Field(0.0, description="Censored period in milliseconds.") + +class SlidingRPViolation(BaseModel): + min_spikes: int = Field(0, description="Contamination is set to np.nan if the unit has less than this many spikes across all segments.") + bin_size_ms: float = Field(0.25, description="The size of binning for the autocorrelogram in ms, by default 0.25.") + window_size_s: float = Field(1, description="Window in seconds to compute correlogram, by default 1.") + exclude_ref_period_below_ms: float = Field(0.5, description="Refractory periods below this value are excluded, by default 0.5") + max_ref_period_ms: float = Field(10, description="Maximum refractory period to test in ms, by default 10 ms.") + contamination_values: Optional[list] = Field(None, description="The contamination values to test, by default np.arange(0.5, 35, 0.5) %") + +class PeakSign(str, Enum): + neg = "neg" + pos = "pos" + both = "both" + +class AmplitudeCutoff(BaseModel): + peak_sign: PeakSign = Field("neg", description="The sign of the peaks.") + num_histogram_bins: int = Field(100, description="The number of bins to use to compute the amplitude histogram.") + histogram_smoothing_value: int = Field(3, description="Controls the smoothing applied to the amplitude histogram.") + amplitudes_bins_min_ratio: int = Field(5, description="The minimum ratio between number of amplitudes for a unit and the number of bins. If the ratio is less than this threshold, the amplitude_cutoff for the unit is set to NaN.") + +class AmplitudeMedian(BaseModel): + peak_sign: PeakSign = Field("neg", description="The sign of the peaks.") + +class NearestNeighbor(BaseModel): + max_spikes: int = Field(10000, description="The number of spikes to use, per cluster. 
Note that the calculation can be very slow when this number is >20000.") + min_spikes: int = Field(10, description="Minimum number of spikes.") + n_neighbors: int = Field(4, description="The number of neighbors to use.") + +class NNIsolation(NearestNeighbor): + n_components: int = Field(10, description="The number of PC components to use to project the snippets to.") + radius_um: int = Field(100, description="The radius, in um, that channels need to be within the peak channel to be included.") + +class QMParams(BaseModel): + presence_ratio: PresenceRatio + snr: SNR + isi_violation: ISIViolation + rp_violation: RPViolation + sliding_rp_violation: SlidingRPViolation + amplitude_cutoff: AmplitudeCutoff + amplitude_median: AmplitudeMedian + nearest_neighbor: NearestNeighbor + nn_isolation: NNIsolation + nn_noise_overlap: NNIsolation + +class QualityMetrics(BaseModel): + qm_params: QMParams = Field(..., description="Quality metric parameters.") + metric_names: List[str] = Field(..., description="List of metric names to compute.") + n_jobs: int = Field(1, description="Number of jobs.") + +class Sparsity(BaseModel): + method: str = Field("radius", description="Method for determining sparsity.") + radius_um: int = Field(100, description="Radius in micrometers for sparsity.") + +class Waveforms(BaseModel): + ms_before: float = Field(3.0, description="Milliseconds before") + ms_after: float = Field(4.0, description="Milliseconds after") + max_spikes_per_unit: int = Field(500, description="Maximum spikes per unit") + return_scaled: bool = Field(True, description="Flag to determine if results should be scaled") + dtype: Optional[str] = Field(None, description="Data type for the waveforms") + precompute_template: Tuple[str, str] = Field(("average", "std"), description="Precomputation template method") + use_relative_path: bool = Field(True, description="Use relative paths") + +class SpikeAmplitudes(BaseModel): + peak_sign: str = Field("neg", description="Sign of the peak") + return_scaled: bool = Field(True, description="Flag to determine if amplitudes should be scaled") + outputs: str = Field("concatenated", description="Output format for the spike amplitudes") + +class Similarity(BaseModel): + method: str = Field("cosine_similarity", description="Method to compute similarity") + +class Correlograms(BaseModel): + window_ms: float = Field(100.0, description="Size of the window in milliseconds") + bin_ms: float = Field(2.0, description="Size of the bin in milliseconds") + +class ISIS(BaseModel): + window_ms: float = Field(100.0, description="Size of the window in milliseconds") + bin_ms: float = Field(5.0, description="Size of the bin in milliseconds") + +class Locations(BaseModel): + method: str = Field("monopolar_triangulation", description="Method to determine locations") + +class TemplateMetrics(BaseModel): + upsampling_factor: int = Field(10, description="Upsampling factor") + sparsity: Optional[str] = Field(None, description="Sparsity method") + +class PrincipalComponents(BaseModel): + n_components: int = Field(5, description="Number of principal components") + mode: str = Field("by_channel_local", description="Mode of principal component analysis") + whiten: bool = Field(True, description="Whiten the components") + +class PostprocessingKwargs(BaseModel): + sparsity: Sparsity + waveforms_deduplicate: Waveforms + waveforms: Waveforms + spike_amplitudes: SpikeAmplitudes + similarity: Similarity + correlograms: Correlograms + isis: ISIS + locations: Locations + template_metrics: TemplateMetrics + 
principal_components: PrincipalComponents + quality_metrics: QualityMetrics + +# ------------------------------ +# Curation Models +# ------------------------------ +class CurationKwargs(BaseModel): + duplicate_threshold: float = Field(0.9, description="Threshold for duplicate units") + isi_violations_ratio_threshold: float = Field(0.5, description="Threshold for ISI violations ratio") + presence_ratio_threshold: float = Field(0.8, description="Threshold for presence ratio") + amplitude_cutoff_threshold: float = Field(0.1, description="Threshold for amplitude cutoff") + + +# ------------------------------ +# Visualization Models +# ------------------------------ +class Timeseries(BaseModel): + n_snippets_per_segment: int = Field(2, description="Number of snippets per segment") + snippet_duration_s: float = Field(0.5, description="Duration of the snippet in seconds") + skip: bool = Field(False, description="Flag to skip") + +class Detection(BaseModel): + method: str = Field("locally_exclusive", description="Method for detection") + peak_sign: str = Field("neg", description="Sign of the peak") + detect_threshold: int = Field(5, description="Detection threshold") + exclude_sweep_ms: float = Field(0.1, description="Exclude sweep in milliseconds") + +class Localization(BaseModel): + ms_before: float = Field(0.1, description="Milliseconds before") + ms_after: float = Field(0.3, description="Milliseconds after") + local_radius_um: float = Field(100.0, description="Local radius in micrometers") + +class Drift(BaseModel): + detection: Detection + localization: Localization + n_skip: int = Field(30, description="Number of skips") + alpha: float = Field(0.15, description="Alpha value") + vmin: int = Field(-200, description="Minimum value") + vmax: int = Field(0, description="Maximum value") + cmap: str = Field("Greys_r", description="Colormap") + figsize: Tuple[int, int] = Field((10, 10), description="Figure size") + +class VisualizationKwargs(BaseModel): + timeseries: Timeseries + drift: Drift -class SortingData(BaseModel): - run_at: RunAt = "local" - run_identifier: str = None - run_description: str = None - source: Source = None - source_data_type: SourceDataType = None - source_data_paths: dict = None - subject_metadata: dict = None - recording_kwargs: dict = None - output_destination: OutputDestination = None - output_path: str = None - sorters_names_list: List[str] = None - sorters_kwargs: dict = None - test_with_toy_recording: bool = None - test_with_subrecording: bool = None - test_subrecording_n_frames: int = None - log_to_file: bool = None \ No newline at end of file diff --git a/rest/requirements.txt b/rest/requirements.txt index fb8d014..8fff453 100644 --- a/rest/requirements.txt +++ b/rest/requirements.txt @@ -1,8 +1,11 @@ -fastapi[all]==0.95.0 -dandi==0.56.0 +fastapi==0.103.2 +uvicorn==0.23.2 +pydantic==1.10.13 +dandi==0.56.2 fsspec==2023.3.0 requests==2.28.2 aiohttp==3.8.4 boto3==1.26.102 SQLAlchemy==2.0.8 -psycopg2==2.9.5 \ No newline at end of file +psycopg2==2.9.5 +docker==6.1.3 \ No newline at end of file diff --git a/rest/routes/__init__.py b/rest/routes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/rest/routes/dandi.py b/rest/routes/dandi.py index 52395f7..035b06b 100644 --- a/rest/routes/dandi.py +++ b/rest/routes/dandi.py @@ -2,8 +2,8 @@ from fastapi.responses import JSONResponse from typing import List -from clients.dandi import DandiClient -from core.settings import settings +from ..clients.dandi import DandiClient +from ..core.settings import settings 
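A minimal illustrative sketch (not part of the patch) of how the new kwargs models above are meant to be used: it composes two of them and serializes them the way LocalDockerClient forwards them to the sorting container, i.e. as JSON strings in SI_*_KWARGS environment variables. All field values are placeholders.

```python
# Illustrative only -- field values are placeholders.
from rest.models.sorting import RunKwargs, SourceDataKwargs

run_kwargs = RunKwargs(
    run_at="local",
    run_identifier="20231011152803",
    run_description="toy run through the docker proxy",
    test_with_toy_recording=True,
)
source_data_kwargs = SourceDataKwargs(
    source_name="dandi",
    source_data_type="nwb",
    source_data_paths={"file": "https://dandiarchive.org/dandiset/000003/0.210813.1807"},
)

# LocalDockerClient.run_sorting() forwards each model as a JSON string
# in an environment variable (pydantic v1 .json(), per requirements.txt).
env_vars = {
    "SI_RUN_KWARGS": run_kwargs.json(),
    "SI_SOURCE_DATA_KWARGS": source_data_kwargs.json(),
}
print(env_vars["SI_RUN_KWARGS"])
```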
router = APIRouter() diff --git a/rest/routes/runs.py b/rest/routes/runs.py index a64eb61..e22e294 100644 --- a/rest/routes/runs.py +++ b/rest/routes/runs.py @@ -2,11 +2,11 @@ from fastapi.responses import JSONResponse from pydantic import BaseModel -from core.logger import logger -from core.settings import settings -from clients.database import DatabaseClient -from clients.aws import AWSClient -from clients.local_worker import LocalWorkerClient +from ..core.logger import logger +from ..core.settings import settings +from ..clients.database import DatabaseClient +from ..clients.aws import AWSClient +from ..clients.local_docker import LocalDockerClient router = APIRouter() @@ -37,7 +37,8 @@ def get_run_info(run_id: str): if "Error running sorter" in run_logs: status = "fail" elif run_info["run_at"] == "local": - local_worker_client = LocalWorkerClient() + # TODO: Implement this + local_worker_client = LocalDockerClient() status, run_logs = local_worker_client.get_run_logs(run_identifier=run_info['identifier']) else: status = "running" diff --git a/rest/routes/sorting.py b/rest/routes/sorting.py index cc97a7d..188724c 100644 --- a/rest/routes/sorting.py +++ b/rest/routes/sorting.py @@ -2,28 +2,56 @@ from fastapi.responses import JSONResponse from datetime import datetime -from core.logger import logger -from core.settings import settings -from clients.dandi import DandiClient -from clients.aws import AWSClient -from clients.local_worker import LocalWorkerClient -from clients.database import DatabaseClient -from models.sorting import SortingData +from ..core.logger import logger +from ..core.settings import settings +from ..clients.dandi import DandiClient +from ..clients.aws import AWSClient +from ..clients.local_docker import LocalDockerClient +from ..clients.database import DatabaseClient +from ..models.sorting import ( + RunKwargs, + SourceDataKwargs, + RecordingKwargs, + PreprocessingKwargs, + SorterKwargs, + PostprocessingKwargs, + CurationKwargs, + VisualizationKwargs, +) router = APIRouter() -def sorting_background_task(payload, run_identifier): +def sorting_background_task( + run_kwargs: RunKwargs, + source_data_kwargs: SourceDataKwargs, + recording_kwargs: RecordingKwargs, + preprocessing_kwargs: PreprocessingKwargs, + sorter_kwargs: SorterKwargs, + postprocessing_kwargs: PostprocessingKwargs, + curation_kwargs: CurationKwargs, + visualization_kwargs: VisualizationKwargs, +): # Run sorting and update db entry status db_client = DatabaseClient(connection_string=settings.DB_CONNECTION_STRING) - run_at = payload.get("run_at", None) + run_at = run_kwargs.run_at try: logger.info(f"Run job at: {run_at}") if run_at == "local": - client_local_worker = LocalWorkerClient() - client_local_worker.run_sorting(**payload) + client_local_worker = LocalDockerClient() + client_local_worker.run_sorting( + run_kwargs=run_kwargs, + source_data_kwargs=source_data_kwargs, + recording_kwargs=recording_kwargs, + preprocessing_kwargs=preprocessing_kwargs, + sorter_kwargs=sorter_kwargs, + postprocessing_kwargs=postprocessing_kwargs, + curation_kwargs=curation_kwargs, + visualization_kwargs=visualization_kwargs, + ) elif run_at == "aws": + # TODO: Implement this job_kwargs = {k.upper(): v for k, v in payload.items()} job_kwargs["DANDI_API_KEY"] = settings.DANDI_API_KEY client_aws = AWSClient() @@ -33,49 +61,66 @@ def sorting_background_task(payload, run_identifier): job_definition=settings.AWS_BATCH_JOB_DEFINITION, job_kwargs=job_kwargs, ) - db_client.update_run(run_identifier=run_identifier, key="status", 
value="running") + # db_client.update_run(run_identifier=run_identifier, key="status", value="running") except Exception as e: - logger.exception(f"Error running sorting job: {run_identifier}.\n {e}") - db_client.update_run(run_identifier=run_identifier, key="status", value="fail") + logger.exception(f"Error running sorting job: {run_kwargs.run_identifier}.\n {e}") + db_client.update_run(run_identifier=run_kwargs.run_identifier, key="status", value="fail") @router.post("/run", response_description="Run Sorting", tags=["sorting"]) -async def route_run_sorting(data: SortingData, background_tasks: BackgroundTasks) -> JSONResponse: - if not data.run_identifier: +async def route_run_sorting( + run_kwargs: RunKwargs, + source_data_kwargs: SourceDataKwargs, + recording_kwargs: RecordingKwargs, + preprocessing_kwargs: PreprocessingKwargs, + sorter_kwargs: SorterKwargs, + postprocessing_kwargs: PostprocessingKwargs, + curation_kwargs: CurationKwargs, + visualization_kwargs: VisualizationKwargs, + background_tasks: BackgroundTasks +) -> JSONResponse: + if not run_kwargs.run_identifier: run_identifier = datetime.now().strftime("%Y%m%d%H%M%S") else: - run_identifier = data.run_identifier + run_identifier = run_kwargs.run_identifier try: # Create Database entries db_client = DatabaseClient(connection_string=settings.DB_CONNECTION_STRING) user = db_client.get_user_info(username="admin") - data_source = db_client.create_data_source( - name=data.run_identifier, - description=data.run_description, - user_id=user.id, - source=data.source, - source_data_type=data.source_data_type, - source_data_paths=str(data.source_data_paths), - recording_kwargs=str(data.recording_kwargs), - ) - run = db_client.create_run( - run_at=data.run_at, - identifier=run_identifier, - description=data.run_description, - last_run=datetime.now().strftime("%Y/%m/%d %H:%M:%S"), - status="running", - data_source_id=data_source.id, - user_id=user.id, - metadata=str(data.json()), - output_destination=data.output_destination, - output_path=data.output_path, - ) + # TODO: Create data source and run entries in database + # data_source = db_client.create_data_source( + # name=run_kwargs.run_identifier, + # description=run_kwargs.run_description, + # user_id=user.id, + # source=source_data_kwargs.source_name, + # source_data_type=source_data_kwargs.source_data_type, + # source_data_paths=str(source_data_kwargs.source_data_paths), + # recording_kwargs=str(recording_kwargs.dict()), + # ) + # run = db_client.create_run( + # run_at=data.run_at, + # identifier=run_identifier, + # description=data.run_description, + # last_run=datetime.now().strftime("%Y/%m/%d %H:%M:%S"), + # status="running", + # data_source_id=data_source.id, + # user_id=user.id, + # metadata=str(data.json()), + # output_destination=data.output_destination, + # output_path=data.output_path, + # ) # Run sorting job background_tasks.add_task( sorting_background_task, - payload=data.dict(), - run_identifier=run_identifier + run_kwargs=run_kwargs, + source_data_kwargs=source_data_kwargs, + recording_kwargs=recording_kwargs, + preprocessing_kwargs=preprocessing_kwargs, + sorter_kwargs=sorter_kwargs, + postprocessing_kwargs=postprocessing_kwargs, + curation_kwargs=curation_kwargs, + visualization_kwargs=visualization_kwargs, ) except Exception as e: @@ -83,5 +128,5 @@ async def route_run_sorting(data: SortingData, background_tasks: BackgroundTasks raise HTTPException(status_code=500, detail="Internal server error") return JSONResponse(content={ "message": "Sorting job submitted", - 
"run_identifier": run.identifier, + "run_identifier": run_kwargs.run_identifier, }) \ No newline at end of file diff --git a/rest/routes/user.py b/rest/routes/user.py index 316a0c7..22e46ca 100644 --- a/rest/routes/user.py +++ b/rest/routes/user.py @@ -2,9 +2,9 @@ from fastapi.responses import JSONResponse from pydantic import BaseModel -from core.logger import logger -from core.settings import settings -from clients.database import DatabaseClient +from ..core.logger import logger +from ..core.settings import settings +from ..clients.database import DatabaseClient router = APIRouter() From 40e40b210d1b6a46ca42016bd9e34891718312a6 Mon Sep 17 00:00:00 2001 From: luiz Date: Wed, 11 Oct 2023 15:59:30 +0200 Subject: [PATCH 2/9] name --- rest/clients/local_docker.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/rest/clients/local_docker.py b/rest/clients/local_docker.py index 587330e..ed738b0 100644 --- a/rest/clients/local_docker.py +++ b/rest/clients/local_docker.py @@ -53,6 +53,7 @@ def run_sorting( } container = self.client.containers.run( + name=f'si-sorting-run-{run_kwargs.run_identifier}', image='python:slim', command=['python', '-c', 'import os; print(os.environ.get("SI_RUN_KWARGS"))'], detach=True, @@ -64,12 +65,7 @@ def run_sorting( capabilities=[['gpu']] ) ] - ) - # if response.status_code == 200: - # self.logger.info("Success!") - # else: - # self.logger.info(f"Error {response.status_code}: {response.content}") - + ) def get_run_logs(self, run_identifier): # TODO: Implement this From 63962cba9085527ea73f4e98501acda8c2f1a076 Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 13 Oct 2023 12:35:38 +0200 Subject: [PATCH 3/9] wip - refactoring main.py --- containers/Dockerfile.ks2_5 | 60 +++--- containers/README_costs.md | 2 +- containers/example_formats.py | 154 +++++++++++++++ containers/main.py | 354 +++++++++++++--------------------- containers/source_format.py | 23 --- containers/utils.py | 7 + rest/models/sorting.py | 2 +- 7 files changed, 330 insertions(+), 272 deletions(-) create mode 100644 containers/example_formats.py delete mode 100644 containers/source_format.py diff --git a/containers/Dockerfile.ks2_5 b/containers/Dockerfile.ks2_5 index fc99494..602d67b 100644 --- a/containers/Dockerfile.ks2_5 +++ b/containers/Dockerfile.ks2_5 @@ -1,34 +1,34 @@ # Spike sorters image FROM spikeinterface/kilosort2_5-compiled-base:0.2.0 as ks25base -# NVIDIA-ready Image -FROM nvidia/cuda:11.6.2-base-ubuntu20.04 - -# Installing Python with miniconda -RUN apt-get update && \ - apt-get install -y build-essential && \ - apt-get install -y wget && \ - apt-get install -y git && \ - apt-get clean && \ - rm -rf /var/lib/apt/lists/* - -ENV CONDA_DIR /home/miniconda3 -ENV LATEST_CONDA_SCRIPT "Miniconda3-py39_23.5.2-0-Linux-x86_64.sh" - -RUN wget --quiet https://repo.anaconda.com/miniconda/$LATEST_CONDA_SCRIPT -O ~/miniconda.sh && \ - bash ~/miniconda.sh -b -p $CONDA_DIR && \ - rm ~/miniconda.sh -ENV PATH=$CONDA_DIR/bin:$PATH - -# Bring Sorter and matlab-related files -COPY --from=ks25base /usr/bin/mlrtapp/ks2_5_compiled /usr/bin/mlrtapp/ks2_5_compiled -ENV PATH="/usr/bin/mlrtapp:${PATH}" -COPY --from=ks25base /opt/matlabruntime /opt/matlabruntime -ENV PATH="/opt/matlabruntime:${PATH}" -COPY --from=ks25base /usr/lib/x86_64-linux-gnu/libXt.so.6 /usr/lib/x86_64-linux-gnu/libXt.so.6 -COPY --from=ks25base /usr/lib/x86_64-linux-gnu/libSM.so.6 /usr/lib/x86_64-linux-gnu/libSM.so.6 -COPY --from=ks25base /usr/lib/x86_64-linux-gnu/libICE.so.6 /usr/lib/x86_64-linux-gnu/libICE.so.6 -ENV 
LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/matlabruntime/R2022b/runtime/glnxa64:/opt/matlabruntime/R2022b/bin/glnxa64:/opt/matlabruntime/R2022b/sys/os/glnxa64:/opt/matlabruntime/R2022b/sys/opengl/lib/glnxa64:/opt/matlabruntime/R2022b/extern/bin/glnxa64 +# # NVIDIA-ready Image +# FROM nvidia/cuda:11.6.2-base-ubuntu20.04 + +# # Installing Python with miniconda +# RUN apt-get update && \ +# apt-get install -y build-essential && \ +# apt-get install -y wget && \ +# apt-get install -y git && \ +# apt-get clean && \ +# rm -rf /var/lib/apt/lists/* + +# ENV CONDA_DIR /home/miniconda3 +# ENV LATEST_CONDA_SCRIPT "Miniconda3-py39_23.5.2-0-Linux-x86_64.sh" + +# RUN wget --quiet https://repo.anaconda.com/miniconda/$LATEST_CONDA_SCRIPT -O ~/miniconda.sh && \ +# bash ~/miniconda.sh -b -p $CONDA_DIR && \ +# rm ~/miniconda.sh +# ENV PATH=$CONDA_DIR/bin:$PATH + +# # Bring Sorter and matlab-related files +# COPY --from=ks25base /usr/bin/mlrtapp/ks2_5_compiled /usr/bin/mlrtapp/ks2_5_compiled +# ENV PATH="/usr/bin/mlrtapp:${PATH}" +# COPY --from=ks25base /opt/matlabruntime /opt/matlabruntime +# ENV PATH="/opt/matlabruntime:${PATH}" +# COPY --from=ks25base /usr/lib/x86_64-linux-gnu/libXt.so.6 /usr/lib/x86_64-linux-gnu/libXt.so.6 +# COPY --from=ks25base /usr/lib/x86_64-linux-gnu/libSM.so.6 /usr/lib/x86_64-linux-gnu/libSM.so.6 +# COPY --from=ks25base /usr/lib/x86_64-linux-gnu/libICE.so.6 /usr/lib/x86_64-linux-gnu/libICE.so.6 +# ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/matlabruntime/R2022b/runtime/glnxa64:/opt/matlabruntime/R2022b/bin/glnxa64:/opt/matlabruntime/R2022b/sys/os/glnxa64:/opt/matlabruntime/R2022b/sys/opengl/lib/glnxa64:/opt/matlabruntime/R2022b/extern/bin/glnxa64 # Copy requirements and script COPY requirements.txt . @@ -37,11 +37,11 @@ RUN pip install -r requirements.txt WORKDIR /app COPY main.py . COPY utils.py . -COPY light_server.py . +# COPY light_server.py . 
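A minimal illustrative sketch (not part of the patch) of launching the resulting image by hand the same way LocalDockerClient does: through the socat docker-proxy defined in docker-compose-dev.yml, with the kwargs injected as SI_*_KWARGS environment variables. The image tag, host port, mount paths, and kwargs values are assumptions/placeholders.

```python
# Illustrative only -- image tag, host port, and kwargs values are placeholders.
import json

import docker

client = docker.DockerClient(base_url="tcp://localhost:2376")  # socat proxy published by docker-compose-dev.yml
container = client.containers.run(
    name="si-sorting-run-manual-test",
    image="si-sorting-worker",  # hypothetical locally built tag for this Dockerfile
    detach=True,
    environment={
        # A real run also needs SI_SOURCE_DATA_KWARGS, SI_PREPROCESSING_KWARGS, etc.,
        # as assembled by LocalDockerClient and shown in containers/example_formats.py.
        "SI_RUN_KWARGS": json.dumps({
            "run_at": "local",
            "run_identifier": "manual-test",
            "run_description": "smoke test",
            "test_with_toy_recording": True,
        }),
    },
    volumes={
        "/tmp/si-results": {"bind": "/results", "mode": "rw"},
        "/tmp/si-logs": {"bind": "/logs", "mode": "rw"},
    },
)
print(container.logs().decode())
```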
RUN mkdir /data RUN mkdir /logs # Get Python stdout logs ENV PYTHONUNBUFFERED=1 -CMD ["python", "light_server.py"] \ No newline at end of file +CMD ["python", "main.py"] \ No newline at end of file diff --git a/containers/README_costs.md b/containers/README_costs.md index 7101b50..6ab4de2 100644 --- a/containers/README_costs.md +++ b/containers/README_costs.md @@ -5,7 +5,7 @@ ECR: - transfer out: $0.09 per GB transferred S3: -- storage: $0.023 per GB / mont +- storage: $0.023 per GB / month - requests: $0.005 (PUT, COPY, POST, LIST) or $0.0004 (GET, SELECT) - transfer out: 100GB free per month, after that $0.02 per GB (other aws services) or $0.09 per GB (outside of aws) diff --git a/containers/example_formats.py b/containers/example_formats.py new file mode 100644 index 0000000..6e1c70a --- /dev/null +++ b/containers/example_formats.py @@ -0,0 +1,154 @@ + + +source_data = { + "source": "dandi", # or "s3" + "source_data_type": "nwb", # or "spikeglx" + "source_data_paths": { + "file": "https://dandiarchive.org/dandiset/000003/0.210813.1807" + }, + "recording_kwargs": { + "electrical_series_name": "ElectricalSeries", + } +} + + +source_data_2 = { + "source": "s3", + "source_data_type": "spikeglx", + "source_data_paths": { + "file_bin": "s3://bucket/path/to/file.ap.bin", + "file_meta": "s3://bucket/path/to/file2.ap.meta", + }, + "recording_kwargs": {} +} + + +preprocessing_params = dict( + preprocessing_strategy="cmr", # 'destripe' or 'cmr' + highpass_filter=dict(freq_min=300.0, margin_ms=5.0), + phase_shift=dict(margin_ms=100.0), + detect_bad_channels=dict( + method="coherence+psd", + dead_channel_threshold=-0.5, + noisy_channel_threshold=1.0, + outside_channel_threshold=-0.3, + n_neighbors=11, + seed=0, + ), + remove_out_channels=False, + remove_bad_channels=False, + max_bad_channel_fraction_to_remove=1.1, + common_reference=dict(reference="global", operator="median"), + highpass_spatial_filter=dict( + n_channel_pad=60, + n_channel_taper=None, + direction="y", + apply_agc=True, + agc_window_length_s=0.01, + highpass_butter_order=3, + highpass_butter_wn=0.01, + ), +) + +qm_params = { + "presence_ratio": {"bin_duration_s": 60}, + "snr": {"peak_sign": "neg", "peak_mode": "extremum", "random_chunk_kwargs_dict": None}, + "isi_violation": {"isi_threshold_ms": 1.5, "min_isi_ms": 0}, + "rp_violation": {"refractory_period_ms": 1, "censored_period_ms": 0.0}, + "sliding_rp_violation": { + "bin_size_ms": 0.25, + "window_size_s": 1, + "exclude_ref_period_below_ms": 0.5, + "max_ref_period_ms": 10, + "contamination_values": None, + }, + "amplitude_cutoff": { + "peak_sign": "neg", + "num_histogram_bins": 100, + "histogram_smoothing_value": 3, + "amplitudes_bins_min_ratio": 5, + }, + "amplitude_median": {"peak_sign": "neg"}, + "nearest_neighbor": {"max_spikes": 10000, "min_spikes": 10, "n_neighbors": 4}, + "nn_isolation": {"max_spikes": 10000, "min_spikes": 10, "n_neighbors": 4, "n_components": 10, "radius_um": 100}, + "nn_noise_overlap": {"max_spikes": 10000, "min_spikes": 10, "n_neighbors": 4, "n_components": 10, "radius_um": 100}, +} +qm_metric_names = [ + "num_spikes", + "firing_rate", + "presence_ratio", + "snr", + "isi_violation", + "rp_violation", + "sliding_rp_violation", + "amplitude_cutoff", + "drift", + "isolation_distance", + "l_ratio", + "d_prime", +] + +postprocessing_params = dict( + sparsity=dict(method="radius", radius_um=100), + waveforms_deduplicate=dict( + ms_before=0.5, + ms_after=1.5, + max_spikes_per_unit=100, + return_scaled=False, + dtype=None, + precompute_template=("average",), + 
use_relative_path=True, + ), + waveforms=dict( + ms_before=3.0, + ms_after=4.0, + max_spikes_per_unit=500, + return_scaled=True, + dtype=None, + precompute_template=("average", "std"), + use_relative_path=True, + ), + spike_amplitudes=dict( + peak_sign="neg", + return_scaled=True, + outputs="concatenated", + ), + similarity=dict(method="cosine_similarity"), + correlograms=dict( + window_ms=100.0, + bin_ms=2.0, + ), + isis=dict( + window_ms=100.0, + bin_ms=5.0, + ), + locations=dict(method="monopolar_triangulation"), + template_metrics=dict(upsampling_factor=10, sparsity=None), + principal_components=dict(n_components=5, mode="by_channel_local", whiten=True), + quality_metrics=dict( + qm_params=qm_params, + metric_names=qm_metric_names, + n_jobs=1 + ), +) + +curation_params = dict( + duplicate_threshold=0.9, + isi_violations_ratio_threshold=0.5, + presence_ratio_threshold=0.8, + amplitude_cutoff_threshold=0.1, +) + +visualization_params = dict( + timeseries=dict(n_snippets_per_segment=2, snippet_duration_s=0.5, skip=False), + drift=dict( + detection=dict(method="locally_exclusive", peak_sign="neg", detect_threshold=5, exclude_sweep_ms=0.1), + localization=dict(ms_before=0.1, ms_after=0.3, local_radius_um=100.0), + n_skip=30, + alpha=0.15, + vmin=-200, + vmax=0, + cmap="Greys_r", + figsize=(10, 10), + ), +) diff --git a/containers/main.py b/containers/main.py index 13349cc..b0f4b43 100644 --- a/containers/main.py +++ b/containers/main.py @@ -1,4 +1,5 @@ import boto3 +import json import os import ast import subprocess @@ -33,6 +34,7 @@ from utils import ( make_logger, + validate_not_none, download_file_from_s3, upload_file_to_bucket, upload_all_files_to_bucket_folder, @@ -40,185 +42,31 @@ ) -### PARAMS: # TODO: probably we should pass a JSON file -n_jobs = os.cpu_count() -job_kwargs = dict(n_jobs=n_jobs, chunk_duration="1s", progress_bar=False) - -preprocessing_params = dict( - preprocessing_strategy="cmr", # 'destripe' or 'cmr' - highpass_filter=dict(freq_min=300.0, margin_ms=5.0), - phase_shift=dict(margin_ms=100.0), - detect_bad_channels=dict( - method="coherence+psd", - dead_channel_threshold=-0.5, - noisy_channel_threshold=1.0, - outside_channel_threshold=-0.3, - n_neighbors=11, - seed=0, - ), - remove_out_channels=False, - remove_bad_channels=False, - max_bad_channel_fraction_to_remove=1.1, - common_reference=dict(reference="global", operator="median"), - highpass_spatial_filter=dict( - n_channel_pad=60, - n_channel_taper=None, - direction="y", - apply_agc=True, - agc_window_length_s=0.01, - highpass_butter_order=3, - highpass_butter_wn=0.01, - ), -) - -qm_params = { - "presence_ratio": {"bin_duration_s": 60}, - "snr": {"peak_sign": "neg", "peak_mode": "extremum", "random_chunk_kwargs_dict": None}, - "isi_violation": {"isi_threshold_ms": 1.5, "min_isi_ms": 0}, - "rp_violation": {"refractory_period_ms": 1, "censored_period_ms": 0.0}, - "sliding_rp_violation": { - "bin_size_ms": 0.25, - "window_size_s": 1, - "exclude_ref_period_below_ms": 0.5, - "max_ref_period_ms": 10, - "contamination_values": None, - }, - "amplitude_cutoff": { - "peak_sign": "neg", - "num_histogram_bins": 100, - "histogram_smoothing_value": 3, - "amplitudes_bins_min_ratio": 5, - }, - "amplitude_median": {"peak_sign": "neg"}, - "nearest_neighbor": {"max_spikes": 10000, "min_spikes": 10, "n_neighbors": 4}, - "nn_isolation": {"max_spikes": 10000, "min_spikes": 10, "n_neighbors": 4, "n_components": 10, "radius_um": 100}, - "nn_noise_overlap": {"max_spikes": 10000, "min_spikes": 10, "n_neighbors": 4, "n_components": 10, 
"radius_um": 100}, -} -qm_metric_names = [ - "num_spikes", - "firing_rate", - "presence_ratio", - "snr", - "isi_violation", - "rp_violation", - "sliding_rp_violation", - "amplitude_cutoff", - "drift", - "isolation_distance", - "l_ratio", - "d_prime", -] - -postprocessing_params = dict( - sparsity=dict(method="radius", radius_um=100), - waveforms_deduplicate=dict( - ms_before=0.5, - ms_after=1.5, - max_spikes_per_unit=100, - return_scaled=False, - dtype=None, - precompute_template=("average",), - use_relative_path=True, - ), - waveforms=dict( - ms_before=3.0, - ms_after=4.0, - max_spikes_per_unit=500, - return_scaled=True, - dtype=None, - precompute_template=("average", "std"), - use_relative_path=True, - ), - spike_amplitudes=dict( - peak_sign="neg", - return_scaled=True, - outputs="concatenated", - ), - similarity=dict(method="cosine_similarity"), - correlograms=dict( - window_ms=100.0, - bin_ms=2.0, - ), - isis=dict( - window_ms=100.0, - bin_ms=5.0, - ), - locations=dict(method="monopolar_triangulation"), - template_metrics=dict(upsampling_factor=10, sparsity=None), - principal_components=dict(n_components=5, mode="by_channel_local", whiten=True), - quality_metrics=dict( - qm_params=qm_params, - metric_names=qm_metric_names, - n_jobs=1 - ), -) - -curation_params = dict( - duplicate_threshold=0.9, - isi_violations_ratio_threshold=0.5, - presence_ratio_threshold=0.8, - amplitude_cutoff_threshold=0.1, -) - -visualization_params = dict( - timeseries=dict(n_snippets_per_segment=2, snippet_duration_s=0.5, skip=False), - drift=dict( - detection=dict(method="locally_exclusive", peak_sign="neg", detect_threshold=5, exclude_sweep_ms=0.1), - localization=dict(ms_before=0.1, ms_after=0.3, local_radius_um=100.0), - n_skip=30, - alpha=0.15, - vmin=-200, - vmax=0, - cmap="Greys_r", - figsize=(10, 10), - ), -) - -data_folder = Path("data/") -scratch_folder = Path("scratch/") -results_folder = Path("results/") - - def main( - run_identifier: str = None, - source: str = None, - source_data_type: str = None, - source_data_paths: dict = None, - recording_kwargs: dict = None, - output_destination: str = None, - output_path: str = None, - sorter_name: str = None, - sorter_kwargs: dict = None, - concatenate_segments: bool = None, + run_at: str, + run_identifier: str, + run_description: str, test_with_toy_recording: bool = None, test_with_subrecording: bool = None, test_subrecording_n_frames: int = None, log_to_file: bool = None, + source_name: str = None, + source_data_type: str = None, + source_data_paths: dict = None, + recording_kwargs: dict = None, + preprocessing_kwargs: dict = None, + sorter_kwargs: dict = None, + postprocessing_kwargs: dict = None, + curation_kwargs: dict = None, + visualization_kwargs: dict = None, + output_destination: str = None, + output_path: str = None ): """ This script should run in an ephemeral Docker container and will: 1. download a dataset with raw electrophysiology traces from a specfied location - 2. run a SpikeInterface pipeline on the raw traces - 3. save the results in a target S3 bucket - - The arguments for this script can be passsed as ENV variables: - - RUN_IDENTIFIER : Unique identifier for this run. - - SOURCE : Source of input data. Choose from: local, s3, dandi. - - SOURCE_DATA_PATHS : Dictionary with paths to source data. Keys are names of data files, values are urls. - - SOURCE_DATA_TYPE : Data type to be read. Choose from: nwb, spikeglx. - - RECORDING_KWARGS : SpikeInterface extractor keyword arguments, specific to chosen dataset type. 
- - OUTPUT_DESTINATION : Destination for saving results. Choose from: local, s3, dandi. - - OUTPUT_PATH : Path for saving results. - If S3, should be a valid S3 path, E.g. s3://... - If local, should be a valid local path, E.g. /data/results - If dandi, should be a valid Dandiset uri, E.g. https://dandiarchive.org/dandiset/000001 - - SORTERS_NAME : Name of sorter to run on source data. - - SORTERS_KWARGS : Parameters for the sorter, stored as a dictionary. - - CONCATENATE_SEGMENTS : If True, concatenates all segments of the recording into one. - - TEST_WITH_TOY_RECORDING : Runs script with a toy dataset. - - TEST_WITH_SUB_RECORDING : Runs script with the first 4 seconds of target dataset. - - TEST_SUB_RECORDING_N_FRAMES : Number of frames to use for sub-recording. - - LOG_TO_FILE : If True, logs will be saved to a file in /logs folder. + 2. run a SpikeInterface pipeline, including preprocessing, spike sorting, postprocessing and curation + 3. save the results in a target S3 bucket or DANDI archive If running this in any AWS service (e.g. Batch, ECS, EC2...) the access to other AWS services such as S3 storage can be given to the container by an IAM role. @@ -230,58 +78,70 @@ def main( If saving results to DANDI archive, or reading from embargoed dandisets, the following ENV variables should be present in the running container: - DANDI_API_KEY - DANDI_API_KEY_STAGING - """ - - # Order of priority for definition of running arguments: - # 1. passed by function - # 2. retrieved from ENV vars - # 3. default value - if not run_identifier: - run_identifier = os.environ.get("RUN_IDENTIFIER", datetime.now().strftime("%Y%m%d%H%M%S")) - if not source: - source = os.environ.get("SOURCE", None) - if source == "None": - source = None - if not source_data_paths: - source_data_paths = eval(os.environ.get("SOURCE_DATA_PATHS", "{}")) - if not source_data_type: - source_data_type = os.environ.get("SOURCE_DATA_TYPE", "nwb") - if not recording_kwargs: - recording_kwargs = ast.literal_eval(os.environ.get("RECORDING_KWARGS", "{}")) - if not output_destination: - output_destination = os.environ.get("OUTPUT_DESTINATION", "s3") - if not output_path: - output_path = os.environ.get("OUTPUT_PATH", None) - if output_path == "None": - output_path = None - if not sorter_name: - sorter_name = os.environ.get("SORTER_NAME", "kilosort2.5") - if not sorter_kwargs: - sorter_kwargs = eval(os.environ.get("SORTER_KWARGS", "{}")) - if not concatenate_segments: - concatenate_segments_str = os.environ.get("CONCATENATE_SEGMENTS", "False") - concatenate_segments = True if concatenate_segments_str == "True" else False - if test_with_toy_recording is None: - test_with_toy_recording = os.environ.get("TEST_WITH_TOY_RECORDING", "False").lower() in ("true", "1", "t") - if test_with_subrecording is None: - test_with_subrecording = os.environ.get("TEST_WITH_SUB_RECORDING", "False").lower() in ("true", "1", "t") - if not test_subrecording_n_frames: - test_subrecording_n_frames = int(os.environ.get("TEST_SUBRECORDING_N_FRAMES", 300000)) - if log_to_file is None: - log_to_file = os.environ.get("LOG_TO_FILE", "False").lower() in ("true", "1", "t") + Parameters + ---------- + run_at : str + Where to run the sorting job. Choose from: aws, local. + run_identifier : str + Unique identifier for this run. + run_description : str + Description for this run. + test_with_toy_recording : bool + If True, runs script with a toy dataset. + test_with_subrecording : bool + If True, runs script with a subrecording of the source dataset. 
+ test_subrecording_n_frames : int + Number of frames to use for sub-recording. + log_to_file : bool + If True, logs will be saved to a file in /logs folder. + source_name : str + Source of input data. Choose from: local, s3, dandi. + source_data_type : str + Data type to be read. Choose from: nwb, spikeglx. + source_data_paths : dict + Dictionary with paths to source data. Keys are names of data files, values are urls or paths. + recording_kwargs : dict + SpikeInterface recording keyword arguments, specific to chosen dataset type. + preprocessing_kwargs : dict + SpikeInterface preprocessing keyword arguments. + sorter_kwargs : dict + SpikeInterface sorter keyword arguments. + postprocessing_kwargs : dict + SpikeInterface postprocessing keyword arguments. + curation_kwargs : dict + SpikeInterface curation keyword arguments. + visualization_kwargs : dict + SpikeInterface visualization keyword arguments. + output_destination : str + Destination for saving results. Choose from: local, s3, dandi. + output_path : str + Path for saving results. + If S3, should be a valid S3 path, E.g. s3://... + If local, should be a valid local path, E.g. /data/results + If dandi, should be a valid Dandiset uri, E.g. https://dandiarchive.org/dandiset/000001 + """ # Set up logging logger = make_logger(run_identifier=run_identifier, log_to_file=log_to_file) filterwarnings(action="ignore", message="No cached namespaces found in .*") filterwarnings(action="ignore", message="Ignoring cached namespace .*") - # SET DEFAULT JOB KWARGS - si.set_global_job_kwargs(**job_kwargs) + # Set SpikeInterface global job kwargs + si.set_global_job_kwargs( + n_jobs=os.cpu_count(), + chunk_duration="1s", + progress_bar=False + ) # Create folders + data_folder = Path("data/") data_folder.mkdir(exist_ok=True) + + scratch_folder = Path("scratch/") scratch_folder.mkdir(exist_ok=True) + + results_folder = Path("results/") results_folder.mkdir(exist_ok=True) tmp_folder = scratch_folder / "tmp" if tmp_folder.is_dir(): @@ -289,9 +149,9 @@ def main( tmp_folder.mkdir() # Checks - if source not in ["local", "s3", "dandi"]: - logger.error(f"Source {source} not supported. Choose from: local, s3, dandi.") - raise ValueError(f"Source {source} not supported. Choose from: local, s3, dandi.") + if source_name not in ["local", "s3", "dandi"]: + logger.error(f"Source {source_name} not supported. Choose from: local, s3, dandi.") + raise ValueError(f"Source {source_name} not supported. Choose from: local, s3, dandi.") # TODO: here we could leverage spikeinterface and add more options if source_data_type not in ["nwb", "spikeglx"]: @@ -324,7 +184,7 @@ def main( recording_name = "toy" # Load data from S3 - elif source == "s3": + elif source_name == "s3": for k, data_url in source_data_paths.items(): if not data_url.startswith("s3://"): logger.error(f"Data url {data_url} is not a valid S3 path. E.g. 
s3://...") @@ -347,7 +207,7 @@ def main( recording = se.read_nwb_recording(file_path=f"/data/{file_name}", **recording_kwargs) recording_name = "recording_on_s3" - elif source == "dandi": + elif source_name == "dandi": dandiset_s3_file_url = source_data_paths["file"] if not dandiset_s3_file_url.startswith("https://dandiarchive"): raise Exception( @@ -708,7 +568,67 @@ def main( if __name__ == "__main__": - main() + # Get run kwargs from ENV variables + run_kwargs = json.loads(os.environ.get("SI_RUN_KWARGS", "{}")) + run_at = validate_not_none(run_kwargs, "run_at") + run_identifier = run_kwargs.get("run_identifier", datetime.now().strftime("%Y%m%d%H%M%S")) + run_description = run_kwargs.get("run_description", "") + test_with_toy_recording = run_kwargs.get("test_with_toy_recording", "False").lower() in ("true", "1", "t") + test_with_subrecording = run_kwargs.get("test_with_subrecording", "False").lower() in ("true", "1", "t") + test_subrecording_n_frames = int(run_kwargs.get("test_subrecording_n_frames", 30000)) + log_to_file = run_kwargs.get("log_to_file", "False").lower() in ("true", "1", "t") + + # Get source data kwargs from ENV variables + source_data_kwargs = json.loads(os.environ.get("SI_SOURCE_DATA_KWARGS", "{}")) + source_name = validate_not_none(source_data_kwargs, "source_name") + source_data_type = validate_not_none(source_data_kwargs, "source_data_type") + source_data_paths = validate_not_none(source_data_kwargs, "source_data_paths") + + # Get recording kwargs from ENV variables + recording_kwargs = json.loads(os.environ.get("SI_RECORDING_KWARGS", "{}")) + + # Get preprocessing kwargs from ENV variables + preprocessing_kwargs = json.loads(os.environ.get("SI_PREPROCESSING_KWARGS", "{}")) + + # Get sorter kwargs from ENV variables + sorter_kwargs = json.loads(os.environ.get("SI_SORTER_KWARGS", "{}")) + + # Get postprocessing kwargs from ENV variables + postprocessing_kwargs = json.loads(os.environ.get("SI_POSTPROCESSING_KWARGS", "{}")) + + # Get curation kwargs from ENV variables + curation_kwargs = json.loads(os.environ.get("SI_CURATION_KWARGS", "{}")) + + # Get visualization kwargs from ENV variables + visualization_kwargs = json.loads(os.environ.get("SI_VISUALIZATION_KWARGS", "{}")) + + # Get output kwargs from ENV variables + output_kwargs = json.loads(os.environ.get("SI_OUTPUT_KWARGS", "{}")) + output_destination = validate_not_none(output_kwargs, "output_destination") + output_path = validate_not_none(output_kwargs, "output_path") + + # Run main function + main( + run_at=run_at, + run_identifier=run_identifier, + run_description=run_description, + test_with_toy_recording=test_with_toy_recording, + test_with_subrecording=test_with_subrecording, + test_subrecording_n_frames=test_subrecording_n_frames, + log_to_file=log_to_file, + source_name=source_name, + source_data_type=source_data_type, + source_data_paths=source_data_paths, + recording_kwargs=recording_kwargs, + preprocessing_kwargs=preprocessing_kwargs, + sorter_kwargs=sorter_kwargs, + postprocessing_kwargs=postprocessing_kwargs, + curation_kwargs=curation_kwargs, + visualization_kwargs=visualization_kwargs, + output_destination=output_destination, + output_path=output_path, + ) + # Known issues: # diff --git a/containers/source_format.py b/containers/source_format.py deleted file mode 100644 index 216717f..0000000 --- a/containers/source_format.py +++ /dev/null @@ -1,23 +0,0 @@ - - -source_data = { - "source": "dandi", # or "s3" - "source_data_type": "nwb", # or "spikeglx" - "source_data_paths": { - "file": 
"https://dandiarchive.org/dandiset/000003/0.210813.1807" - }, - "recording_kwargs": { - "electrical_series_name": "ElectricalSeries", - } -} - - -source_data_2 = { - "source": "s3", - "source_data_type": "spikeglx", - "source_data_paths": { - "file_bin": "s3://bucket/path/to/file.ap.bin", - "file_meta": "s3://bucket/path/to/file2.ap.meta", - }, - "recording_kwargs": {} -} \ No newline at end of file diff --git a/containers/utils.py b/containers/utils.py index 9598648..60de1e0 100644 --- a/containers/utils.py +++ b/containers/utils.py @@ -56,6 +56,13 @@ def make_logger(run_identifier: str, log_to_file: bool): return logger +def validate_not_none(d, k): + v = d.get(k, None) + if v is None: + raise ValueError(f"{k} not specified.") + return v + + def download_file_from_url(url): # ref: https://stackoverflow.com/a/39217788/11483674 local_filename = "/data/filename.nwb" diff --git a/rest/models/sorting.py b/rest/models/sorting.py index bed097b..e78942f 100644 --- a/rest/models/sorting.py +++ b/rest/models/sorting.py @@ -34,8 +34,8 @@ class SourceDataType(str, Enum): class SourceDataKwargs(BaseModel): source_name: SourceName = Field(..., description="Source of input data. Choose from: s3, dandi, local.") - source_data_paths: Dict[str, str] = Field(..., description="Dictionary with paths to source data. Keys are names of data files, values are urls.") source_data_type: SourceDataType = Field(..., description="Type of input data. Choose from: nwb, spikeglx.") + source_data_paths: Dict[str, str] = Field(..., description="Dictionary with paths to source data. Keys are names of data files, values are urls.") # ------------------------------ From 38a2b5a5fffe056c23f80f5625f8c7fe4ffaefca Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 13 Oct 2023 12:57:41 +0200 Subject: [PATCH 4/9] wip refactoring main.py --- containers/main.py | 113 ++++++++++++++++++++--------------- rest/clients/local_docker.py | 3 + rest/models/sorting.py | 9 ++- rest/routes/sorting.py | 9 +-- 4 files changed, 78 insertions(+), 56 deletions(-) diff --git a/containers/main.py b/containers/main.py index b0f4b43..07c4f5d 100644 --- a/containers/main.py +++ b/containers/main.py @@ -127,33 +127,12 @@ def main( filterwarnings(action="ignore", message="No cached namespaces found in .*") filterwarnings(action="ignore", message="Ignoring cached namespace .*") - # Set SpikeInterface global job kwargs - si.set_global_job_kwargs( - n_jobs=os.cpu_count(), - chunk_duration="1s", - progress_bar=False - ) - - # Create folders - data_folder = Path("data/") - data_folder.mkdir(exist_ok=True) - - scratch_folder = Path("scratch/") - scratch_folder.mkdir(exist_ok=True) - - results_folder = Path("results/") - results_folder.mkdir(exist_ok=True) - tmp_folder = scratch_folder / "tmp" - if tmp_folder.is_dir(): - shutil.rmtree(tmp_folder) - tmp_folder.mkdir() - # Checks if source_name not in ["local", "s3", "dandi"]: logger.error(f"Source {source_name} not supported. Choose from: local, s3, dandi.") raise ValueError(f"Source {source_name} not supported. Choose from: local, s3, dandi.") - # TODO: here we could leverage spikeinterface and add more options + # TODO: here we could eventually leverage spikeinterface and add more options if source_data_type not in ["nwb", "spikeglx"]: logger.error(f"Data type {source_data_type} not supported. Choose from: nwb, spikeglx.") raise ValueError(f"Data type {source_data_type} not supported. 
Choose from: nwb, spikeglx.") @@ -174,7 +153,28 @@ def main( output_s3_bucket = output_path_parsed.split("/")[0] output_s3_bucket_folder = "/".join(output_path_parsed.split("/")[1:]) - s3_client = boto3.client("s3") + # Set SpikeInterface global job kwargs + si.set_global_job_kwargs( + n_jobs=os.cpu_count(), + chunk_duration="1s", + progress_bar=False + ) + + # Create SpikeInterface folders + data_folder = Path("data/") + data_folder.mkdir(exist_ok=True) + scratch_folder = Path("scratch/") + scratch_folder.mkdir(exist_ok=True) + results_folder = Path("results/") + results_folder.mkdir(exist_ok=True) + tmp_folder = scratch_folder / "tmp" + if tmp_folder.is_dir(): + shutil.rmtree(tmp_folder) + tmp_folder.mkdir() + + # S3 client + if source_name == "s3" or output_destination == "s3": + s3_client = boto3.client("s3") # Test with toy recording if test_with_toy_recording: @@ -198,7 +198,6 @@ def main( bucket_name=bucket_name, file_path=file_path, ) - logger.info("Reading recording...") # E.g.: se.read_spikeglx(folder_path="/data", stream_id="imec.ap") if source_data_type == "spikeglx": @@ -207,17 +206,16 @@ def main( recording = se.read_nwb_recording(file_path=f"/data/{file_name}", **recording_kwargs) recording_name = "recording_on_s3" + # Load data from DANDI archive elif source_name == "dandi": dandiset_s3_file_url = source_data_paths["file"] if not dandiset_s3_file_url.startswith("https://dandiarchive"): raise Exception( f"DANDISET_S3_FILE_URL should be a valid Dandiset S3 url. Value received was: {dandiset_s3_file_url}" ) - if not test_with_subrecording: logger.info(f"Downloading dataset: {dandiset_s3_file_url}") download_file_from_url(dandiset_s3_file_url) - logger.info("Reading recording from NWB...") recording = se.read_nwb_recording(file_path="/data/filename.nwb", **recording_kwargs) else: @@ -225,12 +223,18 @@ def main( recording = se.read_nwb_recording(file_path=dandiset_s3_file_url, stream_mode="fsspec", **recording_kwargs) recording_name = "recording_on_dandi" + # TODO - Load data from local files + elif source_name == "local": + pass + + # Run with subrecording if test_with_subrecording: n_frames = int(min(test_subrecording_n_frames, recording.get_num_frames())) recording = recording.frame_slice(start_frame=0, end_frame=n_frames) # ------------------------------------------------------------------------------------ # Preprocessing + # ------------------------------------------------------------------------------------ logger.info("Starting preprocessing...") preprocessing_notes = "" preprocessed_folder = tmp_folder / "preprocessed" @@ -238,13 +242,13 @@ def main( logger.info(f"\tDuration: {np.round(recording.get_total_duration(), 2)} s") if "inter_sample_shift" in recording.get_property_keys(): - recording_ps_full = spre.phase_shift(recording, **preprocessing_params["phase_shift"]) + recording_ps_full = spre.phase_shift(recording, **preprocessing_kwargs["phase_shift"]) else: recording_ps_full = recording - recording_hp_full = spre.highpass_filter(recording_ps_full, **preprocessing_params["highpass_filter"]) + recording_hp_full = spre.highpass_filter(recording_ps_full, **preprocessing_kwargs["highpass_filter"]) # IBL bad channel detection - _, channel_labels = spre.detect_bad_channels(recording_hp_full, **preprocessing_params["detect_bad_channels"]) + _, channel_labels = spre.detect_bad_channels(recording_hp_full, **preprocessing_kwargs["detect_bad_channels"]) dead_channel_mask = channel_labels == "dead" noise_channel_mask = channel_labels == "noise" out_channel_mask = channel_labels 
== "out" @@ -257,7 +261,7 @@ def main( out_channel_ids = recording_hp_full.channel_ids[out_channel_mask] all_bad_channel_ids = np.concatenate((dead_channel_ids, noise_channel_ids, out_channel_ids)) - max_bad_channel_fraction_to_remove = preprocessing_params["max_bad_channel_fraction_to_remove"] + max_bad_channel_fraction_to_remove = preprocessing_kwargs["max_bad_channel_fraction_to_remove"] if len(all_bad_channel_ids) >= int(max_bad_channel_fraction_to_remove * recording.get_num_channels()): logger.info( f"\tMore than {max_bad_channel_fraction_to_remove * 100}% bad channels ({len(all_bad_channel_ids)}). " @@ -267,28 +271,28 @@ def main( # in this case, we don't bother sorting return else: - if preprocessing_params["remove_out_channels"]: + if preprocessing_kwargs["remove_out_channels"]: logger.info(f"\tRemoving {len(out_channel_ids)} out channels") recording_rm_out = recording_hp_full.remove_channels(out_channel_ids) preprocessing_notes += f"\n- Removed {len(out_channel_ids)} outside of the brain." else: recording_rm_out = recording_hp_full - recording_processed_cmr = spre.common_reference(recording_rm_out, **preprocessing_params["common_reference"]) + recording_processed_cmr = spre.common_reference(recording_rm_out, **preprocessing_kwargs["common_reference"]) bad_channel_ids = np.concatenate((dead_channel_ids, noise_channel_ids)) recording_interp = spre.interpolate_bad_channels(recording_rm_out, bad_channel_ids) recording_hp_spatial = spre.highpass_spatial_filter( - recording_interp, **preprocessing_params["highpass_spatial_filter"] + recording_interp, **preprocessing_kwargs["highpass_spatial_filter"] ) - preproc_strategy = preprocessing_params["preprocessing_strategy"] + preproc_strategy = preprocessing_kwargs["preprocessing_strategy"] if preproc_strategy == "cmr": recording_processed = recording_processed_cmr else: recording_processed = recording_hp_spatial - if preprocessing_params["remove_bad_channels"]: + if preprocessing_kwargs["remove_bad_channels"]: logger.info(f"\tRemoving {len(bad_channel_ids)} channels after {preproc_strategy} preprocessing") recording_processed = recording_processed.remove_channels(bad_channel_ids) preprocessing_notes += f"\n- Removed {len(bad_channel_ids)} bad channels after preprocessing.\n" @@ -300,6 +304,8 @@ def main( # ------------------------------------------------------------------------------------ # Spike Sorting + # ------------------------------------------------------------------------------------ + sorter_name = sorter_kwargs["sorter_name"] logger.info(f"\n\nStarting spike sorting with {sorter_name}") spikesorting_notes = "" sorting_params = None @@ -313,7 +319,7 @@ def main( if recording_processed.get_num_segments() > 1: recording_processed = si.concatenate_recordings([recording_processed]) - # run ks2.5 + # Run sorter try: sorting = ss.run_sorter( sorter_name, @@ -334,6 +340,7 @@ def main( # remove empty units sorting = sorting.remove_empty_units() + # remove spikes beyond num_samples (if any) sorting = sc.remove_excess_spikes(sorting=sorting, recording=recording_processed) logger.info(f"\tSorting output without empty units: {sorting}") @@ -353,6 +360,7 @@ def main( # ------------------------------------------------------------------------------------ # Postprocessing + # ------------------------------------------------------------------------------------ logger.info("\n\Starting postprocessing...") postprocessing_notes = "" t_postprocessing_start = time.perf_counter() @@ -360,7 +368,7 @@ def main( # first extract some raw waveforms in memory 
to deduplicate based on peak alignment wf_dedup_folder = tmp_folder / "postprocessed" / recording_name we_raw = si.extract_waveforms( - recording_processed, sorting, folder=wf_dedup_folder, **postprocessing_params["waveforms_deduplicate"] + recording_processed, sorting, folder=wf_dedup_folder, **postprocessing_kwargs["waveforms_deduplicate"] ) # de-duplication sorting_deduplicated = sc.remove_redundant_units(we_raw, duplicate_threshold=curation_params["duplicate_threshold"]) @@ -389,32 +397,34 @@ def main( sparsity=sparsity, sparse=True, overwrite=True, - **postprocessing_params["waveforms"], + **postprocessing_kwargs["waveforms"], ) logger.info("\tComputing spike amplitides") - spike_amplitudes = spost.compute_spike_amplitudes(we, **postprocessing_params["spike_amplitudes"]) + spike_amplitudes = spost.compute_spike_amplitudes(we, **postprocessing_kwargs["spike_amplitudes"]) logger.info("\tComputing unit locations") - unit_locations = spost.compute_unit_locations(we, **postprocessing_params["locations"]) + unit_locations = spost.compute_unit_locations(we, **postprocessing_kwargs["locations"]) logger.info("\tComputing spike locations") - spike_locations = spost.compute_spike_locations(we, **postprocessing_params["locations"]) + spike_locations = spost.compute_spike_locations(we, **postprocessing_kwargs["locations"]) logger.info("\tComputing correlograms") - ccg, bins = spost.compute_correlograms(we, **postprocessing_params["correlograms"]) + ccg, bins = spost.compute_correlograms(we, **postprocessing_kwargs["correlograms"]) logger.info("\tComputing ISI histograms") - isi, bins = spost.compute_isi_histograms(we, **postprocessing_params["isis"]) + isi, bins = spost.compute_isi_histograms(we, **postprocessing_kwargs["isis"]) logger.info("\tComputing template similarity") - sim = spost.compute_template_similarity(we, **postprocessing_params["similarity"]) + sim = spost.compute_template_similarity(we, **postprocessing_kwargs["similarity"]) logger.info("\tComputing template metrics") - tm = spost.compute_template_metrics(we, **postprocessing_params["template_metrics"]) + tm = spost.compute_template_metrics(we, **postprocessing_kwargs["template_metrics"]) logger.info("\tComputing PCA") - pca = spost.compute_principal_components(we, **postprocessing_params["principal_components"]) + pca = spost.compute_principal_components(we, **postprocessing_kwargs["principal_components"]) logger.info("\tComputing quality metrics") - qm = sqm.compute_quality_metrics(we, **postprocessing_params["quality_metrics"]) + qm = sqm.compute_quality_metrics(we, **postprocessing_kwargs["quality_metrics"]) t_postprocessing_end = time.perf_counter() elapsed_time_postprocessing = np.round(t_postprocessing_end - t_postprocessing_start, 2) logger.info(f"Postprocessing time: {elapsed_time_postprocessing}s") - ###### CURATION ############## + # ------------------------------------------------------------------------------------ + # Curation + # ------------------------------------------------------------------------------------ logger.info("\n\Starting curation...") curation_notes = "" t_curation_start = time.perf_counter() @@ -452,10 +462,15 @@ def main( elapsed_time_curation = np.round(t_curation_end - t_curation_start, 2) logger.info(f"Curation time: {elapsed_time_curation}s") + # ------------------------------------------------------------------------------------ # TODO: Visualization with FIGURL (needs credentials) + # ------------------------------------------------------------------------------------ + + # 
------------------------------------------------------------------------------------ # Conversion and upload + # ------------------------------------------------------------------------------------ logger.info("Writing sorting results to NWB...") metadata = { "NWBFile": { @@ -603,7 +618,7 @@ def main( visualization_kwargs = json.loads(os.environ.get("SI_VISUALIZATION_KWARGS", "{}")) # Get output kwargs from ENV variables - output_kwargs = json.loads(os.environ.get("SI_OUTPUT_KWARGS", "{}")) + output_kwargs = json.loads(os.environ.get("SI_OUTPUT_DATA_KWARGS", "{}")) output_destination = validate_not_none(output_kwargs, "output_destination") output_path = validate_not_none(output_kwargs, "output_path") diff --git a/rest/clients/local_docker.py b/rest/clients/local_docker.py index ed738b0..fbd2042 100644 --- a/rest/clients/local_docker.py +++ b/rest/clients/local_docker.py @@ -11,6 +11,7 @@ PostprocessingKwargs, CurationKwargs, VisualizationKwargs, + OutputDataKwargs ) @@ -30,6 +31,7 @@ def run_sorting( postprocessing_kwargs: PostprocessingKwargs, curation_kwargs: CurationKwargs, visualization_kwargs: VisualizationKwargs, + output_data_kwargs: OutputDataKwargs, ) -> None: # Pass kwargs as environment variables to the container env_vars = dict( @@ -41,6 +43,7 @@ def run_sorting( SI_POSTPROCESSING_KWARGS=postprocessing_kwargs.json(), SI_CURATION_KWARGS=curation_kwargs.json(), SI_VISUALIZATION_KWARGS=visualization_kwargs.json(), + SI_OUTPUT_DATA_KWARGS=output_data_kwargs.json(), ) # Local volumes to mount diff --git a/rest/models/sorting.py b/rest/models/sorting.py index e78942f..33f71e7 100644 --- a/rest/models/sorting.py +++ b/rest/models/sorting.py @@ -1,6 +1,7 @@ from pydantic import BaseModel, Field, Extra from typing import Optional, Dict, List, Union, Tuple from enum import Enum +from datetime import datetime # ------------------------------ @@ -10,10 +11,13 @@ class RunAt(str, Enum): aws = "aws" local = "local" +def default_run_identifier(): + return datetime.now().strftime("%Y%m%d-%H%M%S") + class RunKwargs(BaseModel): run_at: RunAt = Field(..., description="Where to run the sorting job. 
Choose from: aws, local.") - run_identifier: str = Field(..., description="Unique identifier for the run.") - run_description: str = Field(..., description="Description of the run.") + run_identifier: str = Field(default_factory=default_run_identifier, description="Unique identifier for the run.") + run_description: str = Field(default="", description="Description of the run.") test_with_toy_recording: bool = Field(default=False, description="Whether to test with a toy recording.") test_with_subrecording: bool = Field(default=False, description="Whether to test with a subrecording.") test_subrecording_n_frames: Optional[int] = Field(default=30000, description="Number of frames to use for the subrecording.") @@ -278,4 +282,3 @@ class VisualizationKwargs(BaseModel): timeseries: Timeseries drift: Drift - diff --git a/rest/routes/sorting.py b/rest/routes/sorting.py index 188724c..30f8dac 100644 --- a/rest/routes/sorting.py +++ b/rest/routes/sorting.py @@ -17,6 +17,7 @@ PostprocessingKwargs, CurationKwargs, VisualizationKwargs, + OutputDataKwargs ) @@ -32,6 +33,7 @@ def sorting_background_task( postprocessing_kwargs: PostprocessingKwargs, curation_kwargs: CurationKwargs, visualization_kwargs: VisualizationKwargs, + output_data_kwargs: OutputDataKwargs, ): # Run sorting and update db entry status db_client = DatabaseClient(connection_string=settings.DB_CONNECTION_STRING) @@ -49,6 +51,7 @@ def sorting_background_task( postprocessing_kwargs=postprocessing_kwargs, curation_kwargs=curation_kwargs, visualization_kwargs=visualization_kwargs, + output_data_kwargs=output_data_kwargs, ) elif run_at == "aws": # TODO: Implement this @@ -77,12 +80,9 @@ async def route_run_sorting( postprocessing_kwargs: PostprocessingKwargs, curation_kwargs: CurationKwargs, visualization_kwargs: VisualizationKwargs, + output_data_kwargs: OutputDataKwargs, background_tasks: BackgroundTasks ) -> JSONResponse: - if not run_kwargs.run_identifier: - run_identifier = datetime.now().strftime("%Y%m%d%H%M%S") - else: - run_identifier = run_kwargs.run_identifier try: # Create Database entries db_client = DatabaseClient(connection_string=settings.DB_CONNECTION_STRING) @@ -121,6 +121,7 @@ async def route_run_sorting( postprocessing_kwargs=postprocessing_kwargs, curation_kwargs=curation_kwargs, visualization_kwargs=visualization_kwargs, + output_data_kwargs=output_data_kwargs, ) except Exception as e: From b9d96706f83f977d564782061a44aa9257ed16e9 Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 13 Oct 2023 13:02:49 +0200 Subject: [PATCH 5/9] wip refactor main.py --- containers/main.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/containers/main.py b/containers/main.py index 07c4f5d..8fbd742 100644 --- a/containers/main.py +++ b/containers/main.py @@ -371,7 +371,7 @@ def main( recording_processed, sorting, folder=wf_dedup_folder, **postprocessing_kwargs["waveforms_deduplicate"] ) # de-duplication - sorting_deduplicated = sc.remove_redundant_units(we_raw, duplicate_threshold=curation_params["duplicate_threshold"]) + sorting_deduplicated = sc.remove_redundant_units(we_raw, duplicate_threshold=curation_kwargs["duplicate_threshold"]) logger.info( f"\tNumber of original units: {len(we_raw.sorting.unit_ids)} -- Number of units after de-duplication: {len(sorting_deduplicated.unit_ids)}" ) @@ -380,7 +380,7 @@ def main( ) deduplicated_unit_ids = sorting_deduplicated.unit_ids # use existing deduplicated waveforms to compute sparsity - sparsity_raw = si.compute_sparsity(we_raw, **sparsity_params) + sparsity_raw = 
si.compute_sparsity(we_raw, **postprocessing_kwargs["sparsity"]) sparsity_mask = sparsity_raw.mask[sorting.ids_to_indices(deduplicated_unit_ids), :] sparsity = si.ChannelSparsity(mask=sparsity_mask, unit_ids=deduplicated_unit_ids, channel_ids=recording.channel_ids) shutil.rmtree(wf_dedup_folder) @@ -430,9 +430,9 @@ def main( t_curation_start = time.perf_counter() # curation query - isi_violations_ratio_thr = curation_params["isi_violations_ratio_threshold"] - presence_ratio_thr = curation_params["presence_ratio_threshold"] - amplitude_cutoff_thr = curation_params["amplitude_cutoff_threshold"] + isi_violations_ratio_thr = curation_kwargs["isi_violations_ratio_threshold"] + presence_ratio_thr = curation_kwargs["presence_ratio_threshold"] + amplitude_cutoff_thr = curation_kwargs["amplitude_cutoff_threshold"] curation_query = f"isi_violations_ratio < {isi_violations_ratio_thr} and presence_ratio > {presence_ratio_thr} and amplitude_cutoff < {amplitude_cutoff_thr}" logger.info(f"Curation query: {curation_query}") @@ -489,7 +489,7 @@ def main( results_nwb_folder.mkdir(parents=True, exist_ok=True) output_nwbfile_path = results_nwb_folder / f"{run_identifier}.nwb" - # TODO: Condider writing waveforms instead of sorting + # TODO: Consider writing waveforms instead of sorting # add sorting properties # unit locations sorting.set_property("unit_locations", unit_locations) From 44f9fff31add1ecc2f500208033fee0f04f57f7e Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 13 Oct 2023 15:04:51 +0200 Subject: [PATCH 6/9] wip refactor main.py --- containers/Dockerfile.ks2_5 | 2 +- containers/README.md | 3 +- containers/main.py | 68 ++++++++++++------- ...compose-dev.yml => docker-compose-dev.yaml | 0 docker-compose.yml => docker-compose.yaml | 0 rest/clients/local_docker.py | 13 +++- 6 files changed, 58 insertions(+), 28 deletions(-) rename docker-compose-dev.yml => docker-compose-dev.yaml (100%) rename docker-compose.yml => docker-compose.yaml (100%) diff --git a/containers/Dockerfile.ks2_5 b/containers/Dockerfile.ks2_5 index 602d67b..90bb79b 100644 --- a/containers/Dockerfile.ks2_5 +++ b/containers/Dockerfile.ks2_5 @@ -1,5 +1,5 @@ # Spike sorters image -FROM spikeinterface/kilosort2_5-compiled-base:0.2.0 as ks25base +FROM spikeinterface/kilosort2_5-compiled-base:0.2.0 # # NVIDIA-ready Image # FROM nvidia/cuda:11.6.2-base-ubuntu20.04 diff --git a/containers/README.md b/containers/README.md index c85cdd0..0d22f95 100644 --- a/containers/README.md +++ b/containers/README.md @@ -16,7 +16,8 @@ Basic infrastructure makes use of the following AWS services: Build docker image: ```bash -$ DOCKER_BUILDKIT=1 docker build -t -f . +$ DOCKER_BUILDKIT=1 docker build -t ghcr.io/catalystneuro/si-sorting-ks25:latest -f Dockerfile.ks2_5 . +$ DOCKER_BUILDKIT=1 docker build -t ghcr.io/catalystneuro/si-sorting-ks3:latest -f Dockerfile.ks3 . 
``` Run locally: diff --git a/containers/main.py b/containers/main.py index 8fbd742..ec1103a 100644 --- a/containers/main.py +++ b/containers/main.py @@ -588,10 +588,10 @@ def main( run_at = validate_not_none(run_kwargs, "run_at") run_identifier = run_kwargs.get("run_identifier", datetime.now().strftime("%Y%m%d%H%M%S")) run_description = run_kwargs.get("run_description", "") - test_with_toy_recording = run_kwargs.get("test_with_toy_recording", "False").lower() in ("true", "1", "t") - test_with_subrecording = run_kwargs.get("test_with_subrecording", "False").lower() in ("true", "1", "t") + test_with_toy_recording = run_kwargs.get("test_with_toy_recording", "False") + test_with_subrecording = run_kwargs.get("test_with_subrecording", "False") test_subrecording_n_frames = int(run_kwargs.get("test_subrecording_n_frames", 30000)) - log_to_file = run_kwargs.get("log_to_file", "False").lower() in ("true", "1", "t") + log_to_file = run_kwargs.get("log_to_file", "False") # Get source data kwargs from ENV variables source_data_kwargs = json.loads(os.environ.get("SI_SOURCE_DATA_KWARGS", "{}")) @@ -622,27 +622,47 @@ def main( output_destination = validate_not_none(output_kwargs, "output_destination") output_path = validate_not_none(output_kwargs, "output_path") - # Run main function - main( - run_at=run_at, - run_identifier=run_identifier, - run_description=run_description, - test_with_toy_recording=test_with_toy_recording, - test_with_subrecording=test_with_subrecording, - test_subrecording_n_frames=test_subrecording_n_frames, - log_to_file=log_to_file, - source_name=source_name, - source_data_type=source_data_type, - source_data_paths=source_data_paths, - recording_kwargs=recording_kwargs, - preprocessing_kwargs=preprocessing_kwargs, - sorter_kwargs=sorter_kwargs, - postprocessing_kwargs=postprocessing_kwargs, - curation_kwargs=curation_kwargs, - visualization_kwargs=visualization_kwargs, - output_destination=output_destination, - output_path=output_path, - ) + # # Run main function + # main( + # run_at=run_at, + # run_identifier=run_identifier, + # run_description=run_description, + # test_with_toy_recording=test_with_toy_recording, + # test_with_subrecording=test_with_subrecording, + # test_subrecording_n_frames=test_subrecording_n_frames, + # log_to_file=log_to_file, + # source_name=source_name, + # source_data_type=source_data_type, + # source_data_paths=source_data_paths, + # recording_kwargs=recording_kwargs, + # preprocessing_kwargs=preprocessing_kwargs, + # sorter_kwargs=sorter_kwargs, + # postprocessing_kwargs=postprocessing_kwargs, + # curation_kwargs=curation_kwargs, + # visualization_kwargs=visualization_kwargs, + # output_destination=output_destination, + # output_path=output_path, + # ) + + print("\nRun at: ", run_at) + print("\nRun identifier: ", run_identifier) + print("\nRun description: ", run_description) + print("\nTest with toy recording: ", test_with_toy_recording) + print("\nTest with subrecording: ", test_with_subrecording) + print("\nTest subrecording n frames: ", test_subrecording_n_frames) + print("\nLog to file: ", log_to_file) + print("\nSource name: ", source_name) + print("\nSource data type: ", source_data_type) + print("\nSource data paths: ", source_data_paths) + print("\nRecording kwargs: ", recording_kwargs) + print("\nPreprocessing kwargs: ", preprocessing_kwargs) + print("\nSorter kwargs: ", sorter_kwargs) + print("\nPostprocessing kwargs: ", postprocessing_kwargs) + print("\nCuration kwargs: ", curation_kwargs) + print("\nVisualization kwargs: ", 
visualization_kwargs) + print("\nOutput destination: ", output_destination) + print("\nOutput path: ", output_path) + # Known issues: diff --git a/docker-compose-dev.yml b/docker-compose-dev.yaml similarity index 100% rename from docker-compose-dev.yml rename to docker-compose-dev.yaml diff --git a/docker-compose.yml b/docker-compose.yaml similarity index 100% rename from docker-compose.yml rename to docker-compose.yaml diff --git a/rest/clients/local_docker.py b/rest/clients/local_docker.py index fbd2042..c8216f7 100644 --- a/rest/clients/local_docker.py +++ b/rest/clients/local_docker.py @@ -15,6 +15,15 @@ ) +map_sorter_to_image = { + "kilosort2": "ghcr.io/catalystneuro/si-sorting-ks2:latest", + "kilosort25": "ghcr.io/catalystneuro/si-sorting-ks25:latest", + "kilosort3": "ghcr.io/catalystneuro/si-sorting-ks3:latest", + "ironclust": "ghcr.io/catalystneuro/si-sorting-ironclust:latest", + "spykingcircus": "ghcr.io/catalystneuro/si-sorting-spyking-circus:latest", +} + + class LocalDockerClient: def __init__(self, base_url: str = "tcp://docker-proxy:2375"): @@ -57,8 +66,8 @@ def run_sorting( container = self.client.containers.run( name=f'si-sorting-run-{run_kwargs.run_identifier}', - image='python:slim', - command=['python', '-c', 'import os; print(os.environ.get("SI_RUN_KWARGS"))'], + image=map_sorter_to_image[sorter_kwargs.sorter_name], + command=['python', 'main.py'], detach=True, environment=env_vars, volumes=volumes, From 70fe9a838e59afa0487b4a863931dd0797f78715 Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 13 Oct 2023 15:12:22 +0200 Subject: [PATCH 7/9] uuid container name --- rest/clients/local_docker.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rest/clients/local_docker.py b/rest/clients/local_docker.py index c8216f7..f2d97fa 100644 --- a/rest/clients/local_docker.py +++ b/rest/clients/local_docker.py @@ -1,5 +1,6 @@ from pathlib import Path import docker +import uuid from ..core.logger import logger from ..models.sorting import ( @@ -65,7 +66,7 @@ def run_sorting( } container = self.client.containers.run( - name=f'si-sorting-run-{run_kwargs.run_identifier}', + name=f'si-sorting-run-{run_kwargs.run_identifier}-{uuid.uuid4().hex[:6]}', image=map_sorter_to_image[sorter_kwargs.sorter_name], command=['python', 'main.py'], detach=True, From 177dd4311b7dca2dad6465d7cbea7d4725e54bf8 Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 13 Oct 2023 15:13:09 +0200 Subject: [PATCH 8/9] gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index a9d0917..8801548 100644 --- a/.gitignore +++ b/.gitignore @@ -159,4 +159,5 @@ yarn-error.log* # Miscellaneous *.fuse_hidden* results/ +logs/ *.nwb \ No newline at end of file From 38be68b9c82caf0c9aaf98c69d2d27bd5f25dfdf Mon Sep 17 00:00:00 2001 From: luiz Date: Fri, 13 Oct 2023 15:29:33 +0200 Subject: [PATCH 9/9] refactor main.py --- containers/main.py | 45 +++++++++++++++++++++--------------------- rest/models/sorting.py | 4 ++-- 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/containers/main.py b/containers/main.py index ec1103a..79d00b5 100644 --- a/containers/main.py +++ b/containers/main.py @@ -622,28 +622,7 @@ def main( output_destination = validate_not_none(output_kwargs, "output_destination") output_path = validate_not_none(output_kwargs, "output_path") - # # Run main function - # main( - # run_at=run_at, - # run_identifier=run_identifier, - # run_description=run_description, - # test_with_toy_recording=test_with_toy_recording, - # 
test_with_subrecording=test_with_subrecording, - # test_subrecording_n_frames=test_subrecording_n_frames, - # log_to_file=log_to_file, - # source_name=source_name, - # source_data_type=source_data_type, - # source_data_paths=source_data_paths, - # recording_kwargs=recording_kwargs, - # preprocessing_kwargs=preprocessing_kwargs, - # sorter_kwargs=sorter_kwargs, - # postprocessing_kwargs=postprocessing_kwargs, - # curation_kwargs=curation_kwargs, - # visualization_kwargs=visualization_kwargs, - # output_destination=output_destination, - # output_path=output_path, - # ) - + # Just for checking for now - REMOVE LATER print("\nRun at: ", run_at) print("\nRun identifier: ", run_identifier) print("\nRun description: ", run_description) @@ -663,6 +642,28 @@ def main( print("\nOutput destination: ", output_destination) print("\nOutput path: ", output_path) + # Run main function + main( + run_at=run_at, + run_identifier=run_identifier, + run_description=run_description, + test_with_toy_recording=test_with_toy_recording, + test_with_subrecording=test_with_subrecording, + test_subrecording_n_frames=test_subrecording_n_frames, + log_to_file=log_to_file, + source_name=source_name, + source_data_type=source_data_type, + source_data_paths=source_data_paths, + recording_kwargs=recording_kwargs, + preprocessing_kwargs=preprocessing_kwargs, + sorter_kwargs=sorter_kwargs, + postprocessing_kwargs=postprocessing_kwargs, + curation_kwargs=curation_kwargs, + visualization_kwargs=visualization_kwargs, + output_destination=output_destination, + output_path=output_path, + ) + # Known issues: diff --git a/rest/models/sorting.py b/rest/models/sorting.py index 33f71e7..c85549d 100644 --- a/rest/models/sorting.py +++ b/rest/models/sorting.py @@ -28,9 +28,9 @@ class RunKwargs(BaseModel): # Source Data Models # ------------------------------ class SourceName(str, Enum): + local = "local" s3 = "s3" dandi = "dandi" - local = "local" class SourceDataType(str, Enum): nwb = "nwb" @@ -46,9 +46,9 @@ class SourceDataKwargs(BaseModel): # Output Data Models # ------------------------------ class OutputDestination(str, Enum): + local = "local" s3 = "s3" dandi = "dandi" - local = "local" class OutputDataKwargs(BaseModel): output_destination: OutputDestination = Field(..., description="Destination of output data. Choose from: s3, dandi, local.")
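
---

A note on the kwargs hand-off introduced in this series: each kwargs group is serialized to JSON on the REST side (via the Pydantic models' `.json()`) into an `SI_*_KWARGS` environment variable (now including `SI_OUTPUT_DATA_KWARGS`) and rehydrated with `json.loads` in the container's `__main__` block. The standalone sketch below illustrates that round trip with a plain dict standing in for the Pydantic call and purely illustrative values; it also shows why the string-based boolean parsing (`.lower() in ("true", "1", "t")`) could be dropped — booleans come back from `json.loads` as real booleans.

```python
# Minimal, self-contained sketch of the SI_*_KWARGS round trip.
# Values are illustrative; a plain dict stands in for RunKwargs(...).json().
import json
import os

# REST side (cf. LocalDockerClient.run_sorting): serialize one kwargs group.
os.environ["SI_RUN_KWARGS"] = json.dumps({
    "run_at": "local",
    "run_identifier": "20231013-150000",
    "test_with_toy_recording": True,
    "log_to_file": False,
})

# Container side (cf. containers/main.py __main__ block): rehydrate it.
run_kwargs = json.loads(os.environ.get("SI_RUN_KWARGS", "{}"))
assert run_kwargs["test_with_toy_recording"] is True  # already a bool, no string parsing needed
print(run_kwargs["run_at"], run_kwargs["run_identifier"])
```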
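Because `route_run_sorting` now takes `output_data_kwargs: OutputDataKwargs` as an additional Pydantic body parameter, FastAPI expects the request body to nest each kwargs group under its argument name. A hedged sketch of the resulting payload shape follows: the endpoint path and host are assumptions (the router decorator and prefix are outside this diff), `requests` is used only for illustration, the field values are placeholders, and the remaining kwargs groups are elided.

```python
# Sketch of the nested body shape for route_run_sorting; not a complete request.
import requests  # assumed available for illustration

payload = {
    "run_kwargs": {"run_at": "local", "test_with_toy_recording": True},
    "source_data_kwargs": {
        "source_name": "dandi",
        "source_data_type": "nwb",
        "source_data_paths": {
            "file": "https://dandiarchive.org/dandiset/000003/0.210813.1807"
        },
    },
    "output_data_kwargs": {"output_destination": "local", "output_path": "/results"},
    # recording_kwargs, preprocessing_kwargs, sorter_kwargs, postprocessing_kwargs,
    # curation_kwargs and visualization_kwargs follow the same pattern and are
    # omitted here; FastAPI will return a 422 until they are supplied.
}

# Path and host are assumptions, not taken from this patch series.
response = requests.post("http://localhost:8000/sorting/run", json=payload)
print(response.status_code)
```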