diff --git a/README.rst b/README.rst
index 953e40176..54a00f4b4 100644
--- a/README.rst
+++ b/README.rst
@@ -70,6 +70,7 @@ The main dependencies of Frites are :
* `Numpy `_
* `Scipy `_
* `MNE Python `_
+* `Neo `_
* `Xarray `_
* `Joblib `_
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 4f57de27c..a218e6f5e 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -88,8 +88,8 @@ Highlights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Frites supports inputs from standard libraries like `Numpy `_,
- `MNE Python `_ or more recent ones like
- labelled `Xarray `_ objects.
+ `MNE Python `_, `Neo `_ or
+ more recent ones like labelled `Xarray `_ objects.
+++
diff --git a/docs/source/install.rst b/docs/source/install.rst
index 4da3b02ab..0b5c2288d 100644
--- a/docs/source/install.rst
+++ b/docs/source/install.rst
@@ -13,6 +13,7 @@ The main dependencies of Frites are :
* `Numpy `_
* `Scipy `_
* `MNE `_
+* `Neo `_
* `Xarray `_
* `Joblib `_
diff --git a/frites/conn/conn_covgc.py b/frites/conn/conn_covgc.py
index 9c7580810..e39087ff8 100644
--- a/frites/conn/conn_covgc.py
+++ b/frites/conn/conn_covgc.py
@@ -273,6 +273,7 @@ def conn_covgc(data, dt, lag, t0, step=1, roi=None, times=None, method='gc',
* Standard NumPy arrays of shape (n_epochs, n_roi, n_times)
* mne.Epochs
+ * neo.Block where neo.Segments correspond to epochs
* xarray.DataArray of shape (n_epochs, n_roi, n_times)
dt : int
diff --git a/frites/conn/conn_dfc.py b/frites/conn/conn_dfc.py
index 42691d54f..4cf2b7186 100644
--- a/frites/conn/conn_dfc.py
+++ b/frites/conn/conn_dfc.py
@@ -27,6 +27,7 @@ def conn_dfc(data, win_sample=None, times=None, roi=None, agg_ch=False,
* Standard NumPy arrays of shape (n_epochs, n_roi, n_times)
* mne.Epochs
+ * neo.Block where neo.Segments correspond to epochs
* xarray.DataArray of shape (n_epochs, n_roi, n_times)
win_sample : array_like | None
diff --git a/frites/conn/conn_io.py b/frites/conn/conn_io.py
index 8712b10df..3d7f90c14 100644
--- a/frites/conn/conn_io.py
+++ b/frites/conn/conn_io.py
@@ -3,6 +3,11 @@
import pandas as pd
import xarray as xr
import mne
+try:
+ import neo
+ HAVE_NEO = True
+except ModuleNotFoundError:
+ HAVE_NEO = False
from frites.io import set_log_level, logger
from frites.config import CONFIG
@@ -22,6 +27,7 @@ def conn_io(data, times=None, roi=None, y=None, sfreq=None, agg_ch=False,
* Standard NumPy arrays of shape (n_epochs, n_roi, n_times)
* mne.Epochs
+ * neo.Block where neo.Segments correspond to epochs
* xarray.DataArray of shape (n_epochs, n_roi, n_times)
times : array_like | None
@@ -76,14 +82,26 @@ def conn_io(data, times=None, roi=None, y=None, sfreq=None, agg_ch=False,
# ____________________________ DATA CONVERSION ____________________________
# keep xarray attributes and trials
+ trials, attrs = None, {}
if isinstance(data, xr.DataArray):
trials, attrs = data[data.dims[0]].data, data.attrs
+ elif isinstance(data, (mne.EpochsArray, mne.Epochs)):
+ n_trials = data._data.shape[0]
+ elif 'neo.io' in str(type(data)):
+ if not HAVE_NEO:
+ raise ModuleNotFoundError('Loading Neo objects requires Neo to be installed')
+ assert isinstance(data, neo.Block)
+ n_trials = len(data.segments)
+ # use custom trial ids if provided
+ if all(['trial_id' in seg.annotations for seg in data.segments]):
+ trial_ids = ['trial_id' in seg.annotations for seg in data.segments]
+ trials = np.array(trial_ids, dtype=int)
else:
- if isinstance(data, (mne.EpochsArray, mne.Epochs)):
- n_trials = data._data.shape[0]
- else:
- n_trials = data.shape[0]
- trials, attrs = np.arange(n_trials), {}
+ n_trials = data.shape[0]
+
+ if trials is None:
+ trials = np.arange(n_trials)
+
if y is None:
y = trials
diff --git a/frites/dataset/__init__.py b/frites/dataset/__init__.py
index 4ced5c913..6d56ae330 100644
--- a/frites/dataset/__init__.py
+++ b/frites/dataset/__init__.py
@@ -2,7 +2,7 @@
This submodule includes containers for the neurophysiological data either for
a single-subject or multiple subjects. Several input types are supported
-(NumPy, MNE, Xarray).
+(NumPy, MNE, Neo, Xarray).
"""
from .suj_ephy import SubjectEphy # noqa
from .ds_ephy import DatasetEphy # noqa
diff --git a/frites/dataset/suj_ephy.py b/frites/dataset/suj_ephy.py
index 863e7d6c0..826e5e365 100644
--- a/frites/dataset/suj_ephy.py
+++ b/frites/dataset/suj_ephy.py
@@ -3,6 +3,11 @@
import numpy as np
import xarray as xr
+try:
+ import neo
+ HAVE_NEO = True
+except ModuleNotFoundError:
+ HAVE_NEO = False
import frites
from frites.config import CONFIG
@@ -14,7 +19,7 @@ class SubjectEphy(Attributes):
"""Single-subject electrophysiological data container.
This class can be used to convert the data from different types (e.g
- NumPy, MNE-Python, Xarray) into a single format (xarray.DataArray).
+ NumPy, MNE-Python, Neo, Xarray) into a single format (xarray.DataArray).
Parameters
----------
@@ -28,6 +33,7 @@ class SubjectEphy(Attributes):
where 'mv' refers to an axis to consider as multi-variate
* mne.Epochs or mne.EpochsArray
* mne.EpochsTFR (i.e. non-averaged power)
+ * neo.Block where neo.Segments correspond to Epochs
* xarray.DataArray. In that case `y`, `z`, `roi` and `times` inputs
can be strings that refer to the coordinate name to use in the
DataArray
@@ -130,7 +136,7 @@ def __new__(self, x, y=None, z=None, roi=None, times=None, agg_ch=True,
# get the temporal vector
times = x[times].data if isinstance(times, str) else times
- if 'mne' in str(type(x)): # mne -> xr
+ elif 'mne' in str(type(x)): # mne -> xr
times = x.times if times is None else times
roi = x.info['ch_names'] if roi is None else roi
sfreq = x.info['sfreq'] if sfreq is None else sfreq
@@ -143,7 +149,31 @@ def __new__(self, x, y=None, z=None, roi=None, times=None, agg_ch=True,
else:
_supp_dim = ('freqs', x.freqs)
- if isinstance(x, np.ndarray): # numpy -> xr
+ elif 'neo.core' in str(type(x)):
+ if not HAVE_NEO:
+ raise ModuleNotFoundError('Loading Neo objects requires Neo to be installed')
+ assert isinstance(x, neo.Block)
+
+ # data integrity checks
+ # assert common attributes across signals
+ assert len(np.unique([len(seg.analogsignals) for seg in x.segments]) == 1)
+ assert len(np.unique([seg.analogsignals[0].units for seg in x.segments]) == 1)
+ assert len(np.unique([seg.analogsignals[0].sampling_rate for seg in x.segments]) == 1)
+ assert len(np.unique([seg.analogsignals[0].shape for seg in x.segments]) == 1)
+
+ seg0 = x.segments[0].analogsignals[0]
+ times = seg0.times.magnitude
+ sfreq = seg0.sampling_rate.magnitude
+
+ attrs['sfreq_units'] = seg0.sampling_rate.units
+ attrs['time_units'] = seg0.times.units
+ attrs['signal_units'] = seg0.units
+
+ data = np.stack([seg.analogsignals[0].magnitude for seg in x.segments])
+ # swapping to have time as last dimension
+ data = data.swapaxes(1, -1)
+
+ elif isinstance(x, np.ndarray): # numpy -> xr
data = x
if data.ndim == 4:
if multivariate:
@@ -151,6 +181,11 @@ def __new__(self, x, y=None, z=None, roi=None, times=None, agg_ch=True,
else:
_supp_dim = ('supp', np.arange(data.shape[2]))
+ try:
+ data.ndim
+ except:
+ print('')
+
assert data.ndim <= 4, "Data up to 4-dimensions are supported"
# ____________________________ Y/Z dtypes _____________________________
diff --git a/frites/dataset/tests/test_suj_ephy.py b/frites/dataset/tests/test_suj_ephy.py
index 8715ca5bc..97a300c56 100644
--- a/frites/dataset/tests/test_suj_ephy.py
+++ b/frites/dataset/tests/test_suj_ephy.py
@@ -1,8 +1,15 @@
"""Test SubjectEphy and internal conversions."""
+import pytest
import numpy as np
import xarray as xr
import pandas as pd
import mne
+try:
+ import neo
+ import quantities as pq
+ HAVE_NEO = True
+except ModuleNotFoundError:
+ HAVE_NEO = False
from frites.dataset import SubjectEphy
from frites.utils.perf import id as id_arr
@@ -47,6 +54,18 @@ def _get_data(dtype, ndim):
elif (dtype == 'mne') and (ndim == 4):
info = mne.create_info(ch_names, sfreq, ch_types='seeg')
x_out = mne.time_frequency.EpochsTFR(info, x_4d, times, freqs)
+ elif dtype == 'neo':
+ assert HAVE_NEO, 'Requires Neo to be installed'
+ data = x_3d if ndim == 3 else x_4d
+ block = neo.Block()
+ for epoch_id in range(len(x_3d)):
+ seg = neo.Segment()
+ anasig = neo.AnalogSignal(data[epoch_id].T * pq.dimensionless,
+ t_start=times[0] * pq.s,
+ sampling_rate=sfreq * pq.Hz)
+ seg.analogsignals.append(anasig)
+ block.segments.append(seg)
+ x_out = block
return x_out
@@ -117,6 +136,31 @@ def test_mne_inputs(self):
da_4d = SubjectEphy(mne_4d, y=y_int, z=z, roi=roi, times=times, **kw)
self._test_memory(x_4d, da_4d.data)
+ @pytest.mark.skipif(not HAVE_NEO, reason="requires Neo")
+ def test_neo_inputs(self):
+ """Test function neo_inputs."""
+ # ___________________________ test 3d inputs __________________________
+ # test inputs
+ neo_3d = self._get_data('neo', 3)
+ SubjectEphy(neo_3d, **kw)
+ SubjectEphy(neo_3d, y=y_int, **kw)
+ SubjectEphy(neo_3d, z=z, **kw)
+ SubjectEphy(neo_3d, y=y_int, z=z, roi=roi, **kw)
+ da_3d = SubjectEphy(neo_3d, y=y_int, z=z, roi=roi, times=times, **kw)
+ # hstacking neo objects creates a new array instance, data is copied
+ # self._test_memory(x_3d, da_3d.data)
+
+ # ___________________________ test 4d inputs __________________________
+ # test inputs
+ neo_4d = self._get_data('mne', 4)
+ SubjectEphy(neo_4d, **kw)
+ SubjectEphy(neo_4d, y=y_int, **kw)
+ SubjectEphy(neo_4d, z=z, **kw)
+ SubjectEphy(neo_4d, y=y_int, z=z, roi=roi, **kw)
+ da_4d = SubjectEphy(neo_4d, y=y_int, z=z, roi=roi, times=times, **kw)
+ # hstacking neo objects creates a new array instance, data is copied
+ # self._test_memory(x_4d, da_4d.data)
+
def test_coordinates(self):
"""Test if coordinates and dims are properly set"""
# _________________________ Test Xarray coords ________________________
diff --git a/frites/simulations/sim_generate_data.py b/frites/simulations/sim_generate_data.py
index 462412cc9..690c6c721 100644
--- a/frites/simulations/sim_generate_data.py
+++ b/frites/simulations/sim_generate_data.py
@@ -4,6 +4,13 @@
from scipy.signal import savgol_filter
from itertools import product
+try:
+ import neo
+ import quantities as pq
+ HAVE_NEO = True
+except ModuleNotFoundError:
+ HAVE_NEO = False
+
MA_NAMES = ['L_VCcm', 'L_VCl', 'L_VCs', 'L_Cu', 'L_VCrm', 'L_ITCm', 'L_ITCr',
'L_MTCc', 'L_STCc', 'L_STCr', 'L_MTCr', 'L_ICC', 'L_IPCv',
'L_IPCd', 'L_SPC', 'L_SPCm', 'L_PCm', 'L_PCC', 'L_Sv', 'L_Sdl',
@@ -23,7 +30,8 @@
def sim_single_suj_ephy(modality="meeg", sf=512., n_times=1000, n_roi=1,
n_sites_per_roi=1, n_epochs=100, n_sines=100, f_min=.5,
- f_max=160., noise=10, as_mne=False, random_state=None):
+ f_max=160., noise=10, as_mne=False, as_neo=False,
+ random_state=None):
"""Simulate electrophysiological data of a single subject.
This function generate some illustrative random electrophysiological data
@@ -54,6 +62,8 @@ def sim_single_suj_ephy(modality="meeg", sf=512., n_times=1000, n_roi=1,
Noise level.
as_mne : bool | False
If True, data are converted to a mne.EpochsArray structure
+ as_neo : bool | False
+ If True, data are converted to a neo.Block structure
random_state : int | None
Fix the random state for the reproducibility.
@@ -103,6 +113,18 @@ def sim_single_suj_ephy(modality="meeg", sf=512., n_times=1000, n_roi=1,
from mne import create_info, EpochsArray
info = create_info(roi.tolist(), sf, ch_types='seeg')
signal = EpochsArray(signal, info, tmin=float(time[0]), verbose=False)
+ if as_neo:
+ if not HAVE_NEO:
+ raise ModuleNotFoundError('Loading Neo objects requires Neo to be installed')
+ # building a neo structure with one segment per frites 'epoch'
+ block = neo.Block()
+ for epoch_idx in range(signal.shape[0]):
+ sig = neo.AnalogSignal(signal[epoch_idx].swapaxes(0, -1)*pq.dimensionless,
+ t_start=time[0] * pq.s, sampling_rate=sf * pq.Hz)
+ seg = neo.Segment(trial_id=epoch_idx)
+ seg.analogsignals.append(sig)
+ block.segments.append(seg)
+ signal = block
return signal, roi, time.squeeze()
diff --git a/frites/simulations/sim_mi.py b/frites/simulations/sim_mi.py
index a3f54e747..aa379ff61 100644
--- a/frites/simulations/sim_mi.py
+++ b/frites/simulations/sim_mi.py
@@ -75,7 +75,19 @@ def sim_mi_cc(x, snr=.9):
# if mne types, turn into arrays
if isinstance(x[0], CONFIG["MNE_EPOCHS_TYPE"]):
x = [x[k].get_data() for k in range(len(x))]
+ elif 'neo.core' in str(type(x[0])):
+ pass
+ # TODO: To be discussed also for other functions in this module
+ # Why not use suj_ephy class here?
+ # subject_list = []
+ # for block in x:
+ # subject_data = np.stack([seg.analogsignals[0].magnitude for seg in block.segments])
+ # # reorder dimensions to match (n_epochs, n_channels, n_times)
+ # subject_list.append(subject_data.swapaxes(1, 2))
+ # x = subject_list
+
n_times = x[0].shape[-1]
+
# cluster definition (20% length around central point)
cluster = _get_cluster(n_times, location='center', perc=.2)
# ground truth definition
diff --git a/frites/simulations/tests/test_sim_generate_data.py b/frites/simulations/tests/test_sim_generate_data.py
index 674987568..5536bf28c 100644
--- a/frites/simulations/tests/test_sim_generate_data.py
+++ b/frites/simulations/tests/test_sim_generate_data.py
@@ -2,6 +2,12 @@
import numpy as np
from mne import EpochsArray
+try:
+ import neo
+ HAVE_NEO = True
+except ModuleNotFoundError:
+ HAVE_NEO = False
+
from frites.simulations import (sim_single_suj_ephy, sim_multi_suj_ephy)
@@ -19,6 +25,10 @@ def test_sim_single_suj_ephy(self):
# mne type
data, _, _ = sim_single_suj_ephy(as_mne=True)
assert isinstance(data, EpochsArray)
+ # neo type
+ if HAVE_NEO:
+ data, _, _ = sim_single_suj_ephy(as_neo=True)
+ assert isinstance(data, neo.core.Block)
def test_sim_multi_suj_ephy(self):
"""Test function sim_multi_suj_ephy."""
@@ -34,3 +44,8 @@ def test_sim_multi_suj_ephy(self):
# mne type
data, _, _ = sim_multi_suj_ephy(n_subjects=5, as_mne=True)
assert all([isinstance(k, EpochsArray) for k in data])
+ # neo type
+ if HAVE_NEO:
+ data, _, _ = sim_multi_suj_ephy(n_subjects=5, as_neo=True)
+ assert all([isinstance(k, neo.Block) for k in data])
+
diff --git a/frites/simulations/tests/test_sim_mi.py b/frites/simulations/tests/test_sim_mi.py
index 1daf755ab..1ff41e3bd 100644
--- a/frites/simulations/tests/test_sim_mi.py
+++ b/frites/simulations/tests/test_sim_mi.py
@@ -10,11 +10,12 @@
n_roi = 10
n_sites_per_roi = 1
as_mne = False
+as_neo = False
x, roi, time = sim_multi_suj_ephy(n_subjects=n_subjects, n_epochs=n_epochs,
n_times=n_times, n_roi=n_roi,
n_sites_per_roi=n_sites_per_roi,
- as_mne=as_mne, modality=modality,
- random_state=1)
+ as_mne=as_mne, as_neo=as_neo,
+ modality=modality, random_state=1)
class TestSimMi(object): # noqa
diff --git a/setup.py b/setup.py
index df496ede8..0d67ed368 100644
--- a/setup.py
+++ b/setup.py
@@ -27,7 +27,7 @@ def read(fname):
with open('requirements.txt') as f:
requirements = f.read().splitlines()
-core_deps = ['matplotlib', 'networkx', 'numba', 'dcor', 'scikit-learn']
+core_deps = ['matplotlib', 'networkx', 'numba', 'dcor', 'scikit-learn', 'neo']
test_deps = ['pytest', 'pytest-sugar', 'pytest-cov', 'codecov']
doc_deps = [
'sphinx!=4.1.0', 'sphinx-gallery', 'pydata-sphinx-theme>=0.6.3',