diff --git a/README.rst b/README.rst index 953e40176..54a00f4b4 100644 --- a/README.rst +++ b/README.rst @@ -70,6 +70,7 @@ The main dependencies of Frites are : * `Numpy `_ * `Scipy `_ * `MNE Python `_ +* `Neo `_ * `Xarray `_ * `Joblib `_ diff --git a/docs/source/index.rst b/docs/source/index.rst index 4f57de27c..a218e6f5e 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -88,8 +88,8 @@ Highlights ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Frites supports inputs from standard libraries like `Numpy `_, - `MNE Python `_ or more recent ones like - labelled `Xarray `_ objects. + `MNE Python `_, `Neo `_ or + more recent ones like labelled `Xarray `_ objects. +++ diff --git a/docs/source/install.rst b/docs/source/install.rst index 4da3b02ab..0b5c2288d 100644 --- a/docs/source/install.rst +++ b/docs/source/install.rst @@ -13,6 +13,7 @@ The main dependencies of Frites are : * `Numpy `_ * `Scipy `_ * `MNE `_ +* `Neo `_ * `Xarray `_ * `Joblib `_ diff --git a/frites/conn/conn_covgc.py b/frites/conn/conn_covgc.py index 9c7580810..e39087ff8 100644 --- a/frites/conn/conn_covgc.py +++ b/frites/conn/conn_covgc.py @@ -273,6 +273,7 @@ def conn_covgc(data, dt, lag, t0, step=1, roi=None, times=None, method='gc', * Standard NumPy arrays of shape (n_epochs, n_roi, n_times) * mne.Epochs + * neo.Block where neo.Segments correspond to epochs * xarray.DataArray of shape (n_epochs, n_roi, n_times) dt : int diff --git a/frites/conn/conn_dfc.py b/frites/conn/conn_dfc.py index 42691d54f..4cf2b7186 100644 --- a/frites/conn/conn_dfc.py +++ b/frites/conn/conn_dfc.py @@ -27,6 +27,7 @@ def conn_dfc(data, win_sample=None, times=None, roi=None, agg_ch=False, * Standard NumPy arrays of shape (n_epochs, n_roi, n_times) * mne.Epochs + * neo.Block where neo.Segments correspond to epochs * xarray.DataArray of shape (n_epochs, n_roi, n_times) win_sample : array_like | None diff --git a/frites/conn/conn_io.py b/frites/conn/conn_io.py index 8712b10df..3d7f90c14 100644 --- a/frites/conn/conn_io.py +++ b/frites/conn/conn_io.py @@ -3,6 +3,11 @@ import pandas as pd import xarray as xr import mne +try: + import neo + HAVE_NEO = True +except ModuleNotFoundError: + HAVE_NEO = False from frites.io import set_log_level, logger from frites.config import CONFIG @@ -22,6 +27,7 @@ def conn_io(data, times=None, roi=None, y=None, sfreq=None, agg_ch=False, * Standard NumPy arrays of shape (n_epochs, n_roi, n_times) * mne.Epochs + * neo.Block where neo.Segments correspond to epochs * xarray.DataArray of shape (n_epochs, n_roi, n_times) times : array_like | None @@ -76,14 +82,26 @@ def conn_io(data, times=None, roi=None, y=None, sfreq=None, agg_ch=False, # ____________________________ DATA CONVERSION ____________________________ # keep xarray attributes and trials + trials, attrs = None, {} if isinstance(data, xr.DataArray): trials, attrs = data[data.dims[0]].data, data.attrs + elif isinstance(data, (mne.EpochsArray, mne.Epochs)): + n_trials = data._data.shape[0] + elif 'neo.io' in str(type(data)): + if not HAVE_NEO: + raise ModuleNotFoundError('Loading Neo objects requires Neo to be installed') + assert isinstance(data, neo.Block) + n_trials = len(data.segments) + # use custom trial ids if provided + if all(['trial_id' in seg.annotations for seg in data.segments]): + trial_ids = ['trial_id' in seg.annotations for seg in data.segments] + trials = np.array(trial_ids, dtype=int) else: - if isinstance(data, (mne.EpochsArray, mne.Epochs)): - n_trials = data._data.shape[0] - else: - n_trials = data.shape[0] - trials, attrs = np.arange(n_trials), {} + n_trials = data.shape[0] + + if trials is None: + trials = np.arange(n_trials) + if y is None: y = trials diff --git a/frites/dataset/__init__.py b/frites/dataset/__init__.py index 4ced5c913..6d56ae330 100644 --- a/frites/dataset/__init__.py +++ b/frites/dataset/__init__.py @@ -2,7 +2,7 @@ This submodule includes containers for the neurophysiological data either for a single-subject or multiple subjects. Several input types are supported -(NumPy, MNE, Xarray). +(NumPy, MNE, Neo, Xarray). """ from .suj_ephy import SubjectEphy # noqa from .ds_ephy import DatasetEphy # noqa diff --git a/frites/dataset/suj_ephy.py b/frites/dataset/suj_ephy.py index 863e7d6c0..826e5e365 100644 --- a/frites/dataset/suj_ephy.py +++ b/frites/dataset/suj_ephy.py @@ -3,6 +3,11 @@ import numpy as np import xarray as xr +try: + import neo + HAVE_NEO = True +except ModuleNotFoundError: + HAVE_NEO = False import frites from frites.config import CONFIG @@ -14,7 +19,7 @@ class SubjectEphy(Attributes): """Single-subject electrophysiological data container. This class can be used to convert the data from different types (e.g - NumPy, MNE-Python, Xarray) into a single format (xarray.DataArray). + NumPy, MNE-Python, Neo, Xarray) into a single format (xarray.DataArray). Parameters ---------- @@ -28,6 +33,7 @@ class SubjectEphy(Attributes): where 'mv' refers to an axis to consider as multi-variate * mne.Epochs or mne.EpochsArray * mne.EpochsTFR (i.e. non-averaged power) + * neo.Block where neo.Segments correspond to Epochs * xarray.DataArray. In that case `y`, `z`, `roi` and `times` inputs can be strings that refer to the coordinate name to use in the DataArray @@ -130,7 +136,7 @@ def __new__(self, x, y=None, z=None, roi=None, times=None, agg_ch=True, # get the temporal vector times = x[times].data if isinstance(times, str) else times - if 'mne' in str(type(x)): # mne -> xr + elif 'mne' in str(type(x)): # mne -> xr times = x.times if times is None else times roi = x.info['ch_names'] if roi is None else roi sfreq = x.info['sfreq'] if sfreq is None else sfreq @@ -143,7 +149,31 @@ def __new__(self, x, y=None, z=None, roi=None, times=None, agg_ch=True, else: _supp_dim = ('freqs', x.freqs) - if isinstance(x, np.ndarray): # numpy -> xr + elif 'neo.core' in str(type(x)): + if not HAVE_NEO: + raise ModuleNotFoundError('Loading Neo objects requires Neo to be installed') + assert isinstance(x, neo.Block) + + # data integrity checks + # assert common attributes across signals + assert len(np.unique([len(seg.analogsignals) for seg in x.segments]) == 1) + assert len(np.unique([seg.analogsignals[0].units for seg in x.segments]) == 1) + assert len(np.unique([seg.analogsignals[0].sampling_rate for seg in x.segments]) == 1) + assert len(np.unique([seg.analogsignals[0].shape for seg in x.segments]) == 1) + + seg0 = x.segments[0].analogsignals[0] + times = seg0.times.magnitude + sfreq = seg0.sampling_rate.magnitude + + attrs['sfreq_units'] = seg0.sampling_rate.units + attrs['time_units'] = seg0.times.units + attrs['signal_units'] = seg0.units + + data = np.stack([seg.analogsignals[0].magnitude for seg in x.segments]) + # swapping to have time as last dimension + data = data.swapaxes(1, -1) + + elif isinstance(x, np.ndarray): # numpy -> xr data = x if data.ndim == 4: if multivariate: @@ -151,6 +181,11 @@ def __new__(self, x, y=None, z=None, roi=None, times=None, agg_ch=True, else: _supp_dim = ('supp', np.arange(data.shape[2])) + try: + data.ndim + except: + print('') + assert data.ndim <= 4, "Data up to 4-dimensions are supported" # ____________________________ Y/Z dtypes _____________________________ diff --git a/frites/dataset/tests/test_suj_ephy.py b/frites/dataset/tests/test_suj_ephy.py index 8715ca5bc..97a300c56 100644 --- a/frites/dataset/tests/test_suj_ephy.py +++ b/frites/dataset/tests/test_suj_ephy.py @@ -1,8 +1,15 @@ """Test SubjectEphy and internal conversions.""" +import pytest import numpy as np import xarray as xr import pandas as pd import mne +try: + import neo + import quantities as pq + HAVE_NEO = True +except ModuleNotFoundError: + HAVE_NEO = False from frites.dataset import SubjectEphy from frites.utils.perf import id as id_arr @@ -47,6 +54,18 @@ def _get_data(dtype, ndim): elif (dtype == 'mne') and (ndim == 4): info = mne.create_info(ch_names, sfreq, ch_types='seeg') x_out = mne.time_frequency.EpochsTFR(info, x_4d, times, freqs) + elif dtype == 'neo': + assert HAVE_NEO, 'Requires Neo to be installed' + data = x_3d if ndim == 3 else x_4d + block = neo.Block() + for epoch_id in range(len(x_3d)): + seg = neo.Segment() + anasig = neo.AnalogSignal(data[epoch_id].T * pq.dimensionless, + t_start=times[0] * pq.s, + sampling_rate=sfreq * pq.Hz) + seg.analogsignals.append(anasig) + block.segments.append(seg) + x_out = block return x_out @@ -117,6 +136,31 @@ def test_mne_inputs(self): da_4d = SubjectEphy(mne_4d, y=y_int, z=z, roi=roi, times=times, **kw) self._test_memory(x_4d, da_4d.data) + @pytest.mark.skipif(not HAVE_NEO, reason="requires Neo") + def test_neo_inputs(self): + """Test function neo_inputs.""" + # ___________________________ test 3d inputs __________________________ + # test inputs + neo_3d = self._get_data('neo', 3) + SubjectEphy(neo_3d, **kw) + SubjectEphy(neo_3d, y=y_int, **kw) + SubjectEphy(neo_3d, z=z, **kw) + SubjectEphy(neo_3d, y=y_int, z=z, roi=roi, **kw) + da_3d = SubjectEphy(neo_3d, y=y_int, z=z, roi=roi, times=times, **kw) + # hstacking neo objects creates a new array instance, data is copied + # self._test_memory(x_3d, da_3d.data) + + # ___________________________ test 4d inputs __________________________ + # test inputs + neo_4d = self._get_data('mne', 4) + SubjectEphy(neo_4d, **kw) + SubjectEphy(neo_4d, y=y_int, **kw) + SubjectEphy(neo_4d, z=z, **kw) + SubjectEphy(neo_4d, y=y_int, z=z, roi=roi, **kw) + da_4d = SubjectEphy(neo_4d, y=y_int, z=z, roi=roi, times=times, **kw) + # hstacking neo objects creates a new array instance, data is copied + # self._test_memory(x_4d, da_4d.data) + def test_coordinates(self): """Test if coordinates and dims are properly set""" # _________________________ Test Xarray coords ________________________ diff --git a/frites/simulations/sim_generate_data.py b/frites/simulations/sim_generate_data.py index 462412cc9..690c6c721 100644 --- a/frites/simulations/sim_generate_data.py +++ b/frites/simulations/sim_generate_data.py @@ -4,6 +4,13 @@ from scipy.signal import savgol_filter from itertools import product +try: + import neo + import quantities as pq + HAVE_NEO = True +except ModuleNotFoundError: + HAVE_NEO = False + MA_NAMES = ['L_VCcm', 'L_VCl', 'L_VCs', 'L_Cu', 'L_VCrm', 'L_ITCm', 'L_ITCr', 'L_MTCc', 'L_STCc', 'L_STCr', 'L_MTCr', 'L_ICC', 'L_IPCv', 'L_IPCd', 'L_SPC', 'L_SPCm', 'L_PCm', 'L_PCC', 'L_Sv', 'L_Sdl', @@ -23,7 +30,8 @@ def sim_single_suj_ephy(modality="meeg", sf=512., n_times=1000, n_roi=1, n_sites_per_roi=1, n_epochs=100, n_sines=100, f_min=.5, - f_max=160., noise=10, as_mne=False, random_state=None): + f_max=160., noise=10, as_mne=False, as_neo=False, + random_state=None): """Simulate electrophysiological data of a single subject. This function generate some illustrative random electrophysiological data @@ -54,6 +62,8 @@ def sim_single_suj_ephy(modality="meeg", sf=512., n_times=1000, n_roi=1, Noise level. as_mne : bool | False If True, data are converted to a mne.EpochsArray structure + as_neo : bool | False + If True, data are converted to a neo.Block structure random_state : int | None Fix the random state for the reproducibility. @@ -103,6 +113,18 @@ def sim_single_suj_ephy(modality="meeg", sf=512., n_times=1000, n_roi=1, from mne import create_info, EpochsArray info = create_info(roi.tolist(), sf, ch_types='seeg') signal = EpochsArray(signal, info, tmin=float(time[0]), verbose=False) + if as_neo: + if not HAVE_NEO: + raise ModuleNotFoundError('Loading Neo objects requires Neo to be installed') + # building a neo structure with one segment per frites 'epoch' + block = neo.Block() + for epoch_idx in range(signal.shape[0]): + sig = neo.AnalogSignal(signal[epoch_idx].swapaxes(0, -1)*pq.dimensionless, + t_start=time[0] * pq.s, sampling_rate=sf * pq.Hz) + seg = neo.Segment(trial_id=epoch_idx) + seg.analogsignals.append(sig) + block.segments.append(seg) + signal = block return signal, roi, time.squeeze() diff --git a/frites/simulations/sim_mi.py b/frites/simulations/sim_mi.py index a3f54e747..aa379ff61 100644 --- a/frites/simulations/sim_mi.py +++ b/frites/simulations/sim_mi.py @@ -75,7 +75,19 @@ def sim_mi_cc(x, snr=.9): # if mne types, turn into arrays if isinstance(x[0], CONFIG["MNE_EPOCHS_TYPE"]): x = [x[k].get_data() for k in range(len(x))] + elif 'neo.core' in str(type(x[0])): + pass + # TODO: To be discussed also for other functions in this module + # Why not use suj_ephy class here? + # subject_list = [] + # for block in x: + # subject_data = np.stack([seg.analogsignals[0].magnitude for seg in block.segments]) + # # reorder dimensions to match (n_epochs, n_channels, n_times) + # subject_list.append(subject_data.swapaxes(1, 2)) + # x = subject_list + n_times = x[0].shape[-1] + # cluster definition (20% length around central point) cluster = _get_cluster(n_times, location='center', perc=.2) # ground truth definition diff --git a/frites/simulations/tests/test_sim_generate_data.py b/frites/simulations/tests/test_sim_generate_data.py index 674987568..5536bf28c 100644 --- a/frites/simulations/tests/test_sim_generate_data.py +++ b/frites/simulations/tests/test_sim_generate_data.py @@ -2,6 +2,12 @@ import numpy as np from mne import EpochsArray +try: + import neo + HAVE_NEO = True +except ModuleNotFoundError: + HAVE_NEO = False + from frites.simulations import (sim_single_suj_ephy, sim_multi_suj_ephy) @@ -19,6 +25,10 @@ def test_sim_single_suj_ephy(self): # mne type data, _, _ = sim_single_suj_ephy(as_mne=True) assert isinstance(data, EpochsArray) + # neo type + if HAVE_NEO: + data, _, _ = sim_single_suj_ephy(as_neo=True) + assert isinstance(data, neo.core.Block) def test_sim_multi_suj_ephy(self): """Test function sim_multi_suj_ephy.""" @@ -34,3 +44,8 @@ def test_sim_multi_suj_ephy(self): # mne type data, _, _ = sim_multi_suj_ephy(n_subjects=5, as_mne=True) assert all([isinstance(k, EpochsArray) for k in data]) + # neo type + if HAVE_NEO: + data, _, _ = sim_multi_suj_ephy(n_subjects=5, as_neo=True) + assert all([isinstance(k, neo.Block) for k in data]) + diff --git a/frites/simulations/tests/test_sim_mi.py b/frites/simulations/tests/test_sim_mi.py index 1daf755ab..1ff41e3bd 100644 --- a/frites/simulations/tests/test_sim_mi.py +++ b/frites/simulations/tests/test_sim_mi.py @@ -10,11 +10,12 @@ n_roi = 10 n_sites_per_roi = 1 as_mne = False +as_neo = False x, roi, time = sim_multi_suj_ephy(n_subjects=n_subjects, n_epochs=n_epochs, n_times=n_times, n_roi=n_roi, n_sites_per_roi=n_sites_per_roi, - as_mne=as_mne, modality=modality, - random_state=1) + as_mne=as_mne, as_neo=as_neo, + modality=modality, random_state=1) class TestSimMi(object): # noqa diff --git a/setup.py b/setup.py index df496ede8..0d67ed368 100644 --- a/setup.py +++ b/setup.py @@ -27,7 +27,7 @@ def read(fname): with open('requirements.txt') as f: requirements = f.read().splitlines() -core_deps = ['matplotlib', 'networkx', 'numba', 'dcor', 'scikit-learn'] +core_deps = ['matplotlib', 'networkx', 'numba', 'dcor', 'scikit-learn', 'neo'] test_deps = ['pytest', 'pytest-sugar', 'pytest-cov', 'codecov'] doc_deps = [ 'sphinx!=4.1.0', 'sphinx-gallery', 'pydata-sphinx-theme>=0.6.3',