Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add neo objects as data sources #14

Draft
wants to merge 9 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ The main dependencies of Frites are :
* `Numpy <https://numpy.org/>`_
* `Scipy <https://www.scipy.org/>`_
* `MNE Python <https://mne.tools/stable/index.html>`_
* `Neo <https://pypi.org/project/neo/>`_
* `Xarray <http://xarray.pydata.org/en/stable/>`_
* `Joblib <https://joblib.readthedocs.io/en/latest/>`_

Expand Down
4 changes: 2 additions & 2 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ Highlights
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Frites supports inputs from standard libraries like `Numpy <https://numpy.org/>`_,
`MNE Python <https://mne.tools/stable/index.html>`_ or more recent ones like
labelled `Xarray <http://xarray.pydata.org/en/stable/>`_ objects.
`MNE Python <https://mne.tools/stable/index.html>`_, `Neo <https://pypi.org/project/neo/>`_ or
more recent ones like labelled `Xarray <http://xarray.pydata.org/en/stable/>`_ objects.

+++

Expand Down
1 change: 1 addition & 0 deletions docs/source/install.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ The main dependencies of Frites are :
* `Numpy <https://numpy.org/>`_
* `Scipy <https://www.scipy.org/>`_
* `MNE <https://mne.tools/stable/index.html>`_
* `Neo <https://pypi.org/project/neo>`_
* `Xarray <http://xarray.pydata.org/en/stable/>`_
* `Joblib <https://joblib.readthedocs.io/en/latest/>`_

Expand Down
1 change: 1 addition & 0 deletions frites/conn/conn_covgc.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ def conn_covgc(data, dt, lag, t0, step=1, roi=None, times=None, method='gc',

* Standard NumPy arrays of shape (n_epochs, n_roi, n_times)
* mne.Epochs
* neo.Block where neo.Segments correspond to epochs
* xarray.DataArray of shape (n_epochs, n_roi, n_times)

dt : int
Expand Down
1 change: 1 addition & 0 deletions frites/conn/conn_dfc.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def conn_dfc(data, win_sample=None, times=None, roi=None, agg_ch=False,

* Standard NumPy arrays of shape (n_epochs, n_roi, n_times)
* mne.Epochs
* neo.Block where neo.Segments correspond to epochs
* xarray.DataArray of shape (n_epochs, n_roi, n_times)

win_sample : array_like | None
Expand Down
28 changes: 23 additions & 5 deletions frites/conn/conn_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
import pandas as pd
import xarray as xr
import mne
try:
import neo
HAVE_NEO = True
except ModuleNotFoundError:
HAVE_NEO = False

from frites.io import set_log_level, logger
from frites.config import CONFIG
Expand All @@ -22,6 +27,7 @@ def conn_io(data, times=None, roi=None, y=None, sfreq=None, agg_ch=False,

* Standard NumPy arrays of shape (n_epochs, n_roi, n_times)
* mne.Epochs
* neo.Block where neo.Segments correspond to epochs
* xarray.DataArray of shape (n_epochs, n_roi, n_times)

times : array_like | None
Expand Down Expand Up @@ -76,14 +82,26 @@ def conn_io(data, times=None, roi=None, y=None, sfreq=None, agg_ch=False,

# ____________________________ DATA CONVERSION ____________________________
# keep xarray attributes and trials
trials, attrs = None, {}
if isinstance(data, xr.DataArray):
trials, attrs = data[data.dims[0]].data, data.attrs
elif isinstance(data, (mne.EpochsArray, mne.Epochs)):
n_trials = data._data.shape[0]
elif 'neo.io' in str(type(data)):
if not HAVE_NEO:
raise ModuleNotFoundError('Loading Neo objects requires Neo to be installed')
assert isinstance(data, neo.Block)
n_trials = len(data.segments)
# use custom trial ids if provided
if all(['trial_id' in seg.annotations for seg in data.segments]):
trial_ids = ['trial_id' in seg.annotations for seg in data.segments]
trials = np.array(trial_ids, dtype=int)
else:
if isinstance(data, (mne.EpochsArray, mne.Epochs)):
n_trials = data._data.shape[0]
else:
n_trials = data.shape[0]
trials, attrs = np.arange(n_trials), {}
n_trials = data.shape[0]

if trials is None:
trials = np.arange(n_trials)

if y is None:
y = trials

Expand Down
2 changes: 1 addition & 1 deletion frites/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

This submodule includes containers for the neurophysiological data either for
a single-subject or multiple subjects. Several input types are supported
(NumPy, MNE, Xarray).
(NumPy, MNE, Neo, Xarray).
"""
from .suj_ephy import SubjectEphy # noqa
from .ds_ephy import DatasetEphy # noqa
41 changes: 38 additions & 3 deletions frites/dataset/suj_ephy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@

import numpy as np
import xarray as xr
try:
import neo
HAVE_NEO = True
except ModuleNotFoundError:
HAVE_NEO = False

import frites
from frites.config import CONFIG
Expand All @@ -14,7 +19,7 @@ class SubjectEphy(Attributes):
"""Single-subject electrophysiological data container.

This class can be used to convert the data from different types (e.g
NumPy, MNE-Python, Xarray) into a single format (xarray.DataArray).
NumPy, MNE-Python, Neo, Xarray) into a single format (xarray.DataArray).

Parameters
----------
Expand All @@ -28,6 +33,7 @@ class SubjectEphy(Attributes):
where 'mv' refers to an axis to consider as multi-variate
* mne.Epochs or mne.EpochsArray
* mne.EpochsTFR (i.e. non-averaged power)
* neo.Block where neo.Segments correspond to Epochs
* xarray.DataArray. In that case `y`, `z`, `roi` and `times` inputs
can be strings that refer to the coordinate name to use in the
DataArray
Expand Down Expand Up @@ -130,7 +136,7 @@ def __new__(self, x, y=None, z=None, roi=None, times=None, agg_ch=True,
# get the temporal vector
times = x[times].data if isinstance(times, str) else times

if 'mne' in str(type(x)): # mne -> xr
elif 'mne' in str(type(x)): # mne -> xr
times = x.times if times is None else times
roi = x.info['ch_names'] if roi is None else roi
sfreq = x.info['sfreq'] if sfreq is None else sfreq
Expand All @@ -143,14 +149,43 @@ def __new__(self, x, y=None, z=None, roi=None, times=None, agg_ch=True,
else:
_supp_dim = ('freqs', x.freqs)

if isinstance(x, np.ndarray): # numpy -> xr
elif 'neo.core' in str(type(x)):
if not HAVE_NEO:
raise ModuleNotFoundError('Loading Neo objects requires Neo to be installed')
assert isinstance(x, neo.Block)

# data integrity checks
# assert common attributes across signals
assert len(np.unique([len(seg.analogsignals) for seg in x.segments]) == 1)
assert len(np.unique([seg.analogsignals[0].units for seg in x.segments]) == 1)
assert len(np.unique([seg.analogsignals[0].sampling_rate for seg in x.segments]) == 1)
assert len(np.unique([seg.analogsignals[0].shape for seg in x.segments]) == 1)

seg0 = x.segments[0].analogsignals[0]
times = seg0.times.magnitude
sfreq = seg0.sampling_rate.magnitude

attrs['sfreq_units'] = seg0.sampling_rate.units
attrs['time_units'] = seg0.times.units
attrs['signal_units'] = seg0.units

data = np.stack([seg.analogsignals[0].magnitude for seg in x.segments])
# swapping to have time as last dimension
data = data.swapaxes(1, -1)

elif isinstance(x, np.ndarray): # numpy -> xr
data = x
if data.ndim == 4:
if multivariate:
_supp_dim = ('mv', np.full((data.shape[2]), np.nan))
else:
_supp_dim = ('supp', np.arange(data.shape[2]))

try:
data.ndim
except:
print('')

assert data.ndim <= 4, "Data up to 4-dimensions are supported"

# ____________________________ Y/Z dtypes _____________________________
Expand Down
44 changes: 44 additions & 0 deletions frites/dataset/tests/test_suj_ephy.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
"""Test SubjectEphy and internal conversions."""
import pytest
import numpy as np
import xarray as xr
import pandas as pd
import mne
try:
import neo
import quantities as pq
HAVE_NEO = True
except ModuleNotFoundError:
HAVE_NEO = False

from frites.dataset import SubjectEphy
from frites.utils.perf import id as id_arr
Expand Down Expand Up @@ -47,6 +54,18 @@ def _get_data(dtype, ndim):
elif (dtype == 'mne') and (ndim == 4):
info = mne.create_info(ch_names, sfreq, ch_types='seeg')
x_out = mne.time_frequency.EpochsTFR(info, x_4d, times, freqs)
elif dtype == 'neo':
assert HAVE_NEO, 'Requires Neo to be installed'
data = x_3d if ndim == 3 else x_4d
block = neo.Block()
for epoch_id in range(len(x_3d)):
seg = neo.Segment()
anasig = neo.AnalogSignal(data[epoch_id].T * pq.dimensionless,
t_start=times[0] * pq.s,
sampling_rate=sfreq * pq.Hz)
seg.analogsignals.append(anasig)
block.segments.append(seg)
x_out = block

return x_out

Expand Down Expand Up @@ -117,6 +136,31 @@ def test_mne_inputs(self):
da_4d = SubjectEphy(mne_4d, y=y_int, z=z, roi=roi, times=times, **kw)
self._test_memory(x_4d, da_4d.data)

@pytest.mark.skipif(not HAVE_NEO, reason="requires Neo")
def test_neo_inputs(self):
"""Test function neo_inputs."""
# ___________________________ test 3d inputs __________________________
# test inputs
neo_3d = self._get_data('neo', 3)
SubjectEphy(neo_3d, **kw)
SubjectEphy(neo_3d, y=y_int, **kw)
SubjectEphy(neo_3d, z=z, **kw)
SubjectEphy(neo_3d, y=y_int, z=z, roi=roi, **kw)
da_3d = SubjectEphy(neo_3d, y=y_int, z=z, roi=roi, times=times, **kw)
# hstacking neo objects creates a new array instance, data is copied
# self._test_memory(x_3d, da_3d.data)

# ___________________________ test 4d inputs __________________________
# test inputs
neo_4d = self._get_data('mne', 4)
SubjectEphy(neo_4d, **kw)
SubjectEphy(neo_4d, y=y_int, **kw)
SubjectEphy(neo_4d, z=z, **kw)
SubjectEphy(neo_4d, y=y_int, z=z, roi=roi, **kw)
da_4d = SubjectEphy(neo_4d, y=y_int, z=z, roi=roi, times=times, **kw)
# hstacking neo objects creates a new array instance, data is copied
# self._test_memory(x_4d, da_4d.data)

def test_coordinates(self):
"""Test if coordinates and dims are properly set"""
# _________________________ Test Xarray coords ________________________
Expand Down
24 changes: 23 additions & 1 deletion frites/simulations/sim_generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@
from scipy.signal import savgol_filter
from itertools import product

try:
import neo
import quantities as pq
HAVE_NEO = True
except ModuleNotFoundError:
HAVE_NEO = False

MA_NAMES = ['L_VCcm', 'L_VCl', 'L_VCs', 'L_Cu', 'L_VCrm', 'L_ITCm', 'L_ITCr',
'L_MTCc', 'L_STCc', 'L_STCr', 'L_MTCr', 'L_ICC', 'L_IPCv',
'L_IPCd', 'L_SPC', 'L_SPCm', 'L_PCm', 'L_PCC', 'L_Sv', 'L_Sdl',
Expand All @@ -23,7 +30,8 @@

def sim_single_suj_ephy(modality="meeg", sf=512., n_times=1000, n_roi=1,
n_sites_per_roi=1, n_epochs=100, n_sines=100, f_min=.5,
f_max=160., noise=10, as_mne=False, random_state=None):
f_max=160., noise=10, as_mne=False, as_neo=False,
random_state=None):
"""Simulate electrophysiological data of a single subject.

This function generate some illustrative random electrophysiological data
Expand Down Expand Up @@ -54,6 +62,8 @@ def sim_single_suj_ephy(modality="meeg", sf=512., n_times=1000, n_roi=1,
Noise level.
as_mne : bool | False
If True, data are converted to a mne.EpochsArray structure
as_neo : bool | False
If True, data are converted to a neo.Block structure
random_state : int | None
Fix the random state for the reproducibility.

Expand Down Expand Up @@ -103,6 +113,18 @@ def sim_single_suj_ephy(modality="meeg", sf=512., n_times=1000, n_roi=1,
from mne import create_info, EpochsArray
info = create_info(roi.tolist(), sf, ch_types='seeg')
signal = EpochsArray(signal, info, tmin=float(time[0]), verbose=False)
if as_neo:
if not HAVE_NEO:
raise ModuleNotFoundError('Loading Neo objects requires Neo to be installed')
# building a neo structure with one segment per frites 'epoch'
block = neo.Block()
for epoch_idx in range(signal.shape[0]):
sig = neo.AnalogSignal(signal[epoch_idx].swapaxes(0, -1)*pq.dimensionless,
t_start=time[0] * pq.s, sampling_rate=sf * pq.Hz)
seg = neo.Segment(trial_id=epoch_idx)
seg.analogsignals.append(sig)
block.segments.append(seg)
signal = block
return signal, roi, time.squeeze()


Expand Down
12 changes: 12 additions & 0 deletions frites/simulations/sim_mi.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,19 @@ def sim_mi_cc(x, snr=.9):
# if mne types, turn into arrays
if isinstance(x[0], CONFIG["MNE_EPOCHS_TYPE"]):
x = [x[k].get_data() for k in range(len(x))]
elif 'neo.core' in str(type(x[0])):
pass
# TODO: To be discussed also for other functions in this module
# Why not use suj_ephy class here?
# subject_list = []
# for block in x:
# subject_data = np.stack([seg.analogsignals[0].magnitude for seg in block.segments])
# # reorder dimensions to match (n_epochs, n_channels, n_times)
# subject_list.append(subject_data.swapaxes(1, 2))
# x = subject_list

n_times = x[0].shape[-1]

# cluster definition (20% length around central point)
cluster = _get_cluster(n_times, location='center', perc=.2)
# ground truth definition
Expand Down
15 changes: 15 additions & 0 deletions frites/simulations/tests/test_sim_generate_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@
import numpy as np
from mne import EpochsArray

try:
import neo
HAVE_NEO = True
except ModuleNotFoundError:
HAVE_NEO = False

from frites.simulations import (sim_single_suj_ephy, sim_multi_suj_ephy)


Expand All @@ -19,6 +25,10 @@ def test_sim_single_suj_ephy(self):
# mne type
data, _, _ = sim_single_suj_ephy(as_mne=True)
assert isinstance(data, EpochsArray)
# neo type
if HAVE_NEO:
data, _, _ = sim_single_suj_ephy(as_neo=True)
assert isinstance(data, neo.core.Block)

def test_sim_multi_suj_ephy(self):
"""Test function sim_multi_suj_ephy."""
Expand All @@ -34,3 +44,8 @@ def test_sim_multi_suj_ephy(self):
# mne type
data, _, _ = sim_multi_suj_ephy(n_subjects=5, as_mne=True)
assert all([isinstance(k, EpochsArray) for k in data])
# neo type
if HAVE_NEO:
data, _, _ = sim_multi_suj_ephy(n_subjects=5, as_neo=True)
assert all([isinstance(k, neo.Block) for k in data])

5 changes: 3 additions & 2 deletions frites/simulations/tests/test_sim_mi.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
n_roi = 10
n_sites_per_roi = 1
as_mne = False
as_neo = False
x, roi, time = sim_multi_suj_ephy(n_subjects=n_subjects, n_epochs=n_epochs,
n_times=n_times, n_roi=n_roi,
n_sites_per_roi=n_sites_per_roi,
as_mne=as_mne, modality=modality,
random_state=1)
as_mne=as_mne, as_neo=as_neo,
modality=modality, random_state=1)


class TestSimMi(object): # noqa
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def read(fname):
with open('requirements.txt') as f:
requirements = f.read().splitlines()

core_deps = ['matplotlib', 'networkx', 'numba', 'dcor', 'scikit-learn']
core_deps = ['matplotlib', 'networkx', 'numba', 'dcor', 'scikit-learn', 'neo']
test_deps = ['pytest', 'pytest-sugar', 'pytest-cov', 'codecov']
doc_deps = [
'sphinx!=4.1.0', 'sphinx-gallery', 'pydata-sphinx-theme>=0.6.3',
Expand Down