From 7a3c23968ff0f3b0c715cb5b2455daa45f819d47 Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Thu, 9 Jan 2025 13:12:59 -0500 Subject: [PATCH 1/7] start of phase_precession notebook --- docs/source/full/day1/phase_precession.md | 323 ++++++++++++++++++++++ 1 file changed, 323 insertions(+) create mode 100644 docs/source/full/day1/phase_precession.md diff --git a/docs/source/full/day1/phase_precession.md b/docs/source/full/day1/phase_precession.md new file mode 100644 index 0000000..a5ba37b --- /dev/null +++ b/docs/source/full/day1/phase_precession.md @@ -0,0 +1,323 @@ +--- +jupytext: + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.16.6 +kernelspec: + display_name: Python 3 (ipykernel) + language: python + name: python3 +--- + +Spike-phase coupling and decoding: +Phase precession and hippocampal sequences +========================================== + +In this tutorial we will learn how to apply two methods included in pynapple: filtering and decoding. We'll apply these methods to demonstrate and visualize some well-known physiological properties of hippocampal activity, specifically phase presession of place cells and sequential coordination of place cell activity during theta oscillations. + +Background +---------- +- hippocampus (rat) +- place cells +- LFP and theta oscillation +- phase precession +- theta sequences + +```{code-cell} ipython3 +import math +import os + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import requests +import scipy +import seaborn as sns +import tqdm +import pynapple as nap +import nemos as nmo + +custom_params = {"axes.spines.right": False, "axes.spines.top": False} +sns.set_theme(style="ticks", palette="colorblind", font_scale=1.5, rc=custom_params) +``` + +*** +Downloading the data +-------------------- +The data set we'll be looking at is from the manuscript [Diversity in neural firing dynamics supports both rigid and learned hippocampal sequences](https://www.science.org/doi/10.1126/science.aad1935). In this study, the authors collected electrophisiology data in rats across multiple sites in layer CA1 of hippocampus to extract the LFP alongside spiking activity of many simultaneous pyramidal units. In each recording session, data were collected while the rats explored a novel environment (a linear track), as well as during sleep before and after exploration. In our following analyses, we'll focus on the exploration period of a single rat and recording session. + +First, we need to download the data and save it locally. Since the file size of a recording session can be large from the LFP saved for each recorded channel, we'll use a smaller file that contains the spiking activity and the LFP from a single, representative channel, which is hosted on [OSF](https://osf.io/2dfvp). This smaller file, like the original data, is saved as an [NWB](https://www.nwb.org) file. + +Full dataset: https://dandiarchive.org/dandiset/000044/0.210812.1516 + +(is there a simpler way of doing this? i.e. use nemos?) +- make workshop pooch, make sure its part of the pre-workshop download + +```{code-cell} ipython3 +path = "Achilles_10252013.nwb" +path = nmo.fetch.fetch_data(path) +# if path not in os.listdir("."): +# r = requests.get(f"https://osf.io/2dfvp/download", stream=True) +# block_size = 1024 * 1024 +# with open(path, "wb") as f: +# for data in tqdm.tqdm( +# r.iter_content(block_size), +# unit="MB", +# unit_scale=True, +# total=math.ceil(int(r.headers.get("content-length", 0)) // block_size), +# ): +# f.write(data) +``` + +*** +Loading the data +------------------ +With the file downloaded, we can use the pynapple function `load_file` to load in the data, which is able to handle NWB file types. + +```{code-cell} ipython3 +data = nap.load_file(path) +print(data) +``` + +*** +Selecting a single run +----------------------------------- +for visualization, grab a single run down the linear track (selected in advance) + +```{code-cell} ipython3 +ex_run_ep = data["forward_ep"][9] +``` + +Restrict data to awake epochs: lfp, spikes, and position + +```{code-cell} ipython3 +lfp_run = data["eeg"][:,0].restrict(data["forward_ep"]) +spikes = data["units"].restrict(data["forward_ep"]) +position = data["position"].restrict(data["forward_ep"]) +``` + +*** +Plotting the LFP Activity +----------------------------------- +plot LFP and animal position during trial + +```{code-cell} ipython3 +fig, axs = plt.subplots(2, 1, constrained_layout=True, figsize=(10, 6), sharex=True) + +# plot LFP +axs[0].plot(lfp_run.restrict(ex_run_ep)) +axs[0].set_title("Local Field Potential on Linear Track") +axs[0].set_ylabel("LFP (a.u.)") +# axs[0].set_xlabel("time (s)") + +# plot animal's position +axs[1].plot(position.restrict(ex_run_ep)) +axs[1].set_title("Animal Position on Linear Track") +axs[1].set_ylabel("Position (cm)") # LOOK UP UNITS +axs[1].set_xlabel("time (s)") +``` + +*** +Getting the Wavelet Decomposition +----------------------------------- +As we would expect, it looks like we have a very strong theta oscillation within our data +- this is a common feature of REM sleep. Let's perform a wavelet decomposition, +as we did in the last tutorial, to see get a more informative breakdown of the +frequencies present in the data. + +We must define the frequency set that we'd like to use for our decomposition. + +```{code-cell} ipython3 +freqs = np.geomspace(5, 200, 25) +``` + +We compute the wavelet transform on our LFP data (only during the example interval). + +double check: FS tracked to https://www.jneurosci.org/content/28/26/6731 methods + +```{code-cell} ipython3 +FS = 1250 # We know from the methods of the paper +cwt_run = nap.compute_wavelet_transform(lfp_run.restrict(ex_run_ep), fs=FS, freqs=freqs) +``` + +*** +Now let's plot the calculated wavelet scalogram. + +```{code-cell} ipython3 +# Define wavelet decomposition plotting function +def plot_timefrequency(freqs, powers, ax=None): + im = ax.imshow(np.abs(powers), aspect="auto") + ax.invert_yaxis() + ax.set_xlabel("Time (s)") + ax.set_ylabel("Frequency (Hz)") + ax.get_xaxis().set_visible(False) + ax.set(yticks=np.arange(len(freqs))[::2], yticklabels=np.rint(freqs[::2])) + ax.grid(False) + return im + +fig, axs = plt.subplots(2, 1, figsize=(10,6), constrained_layout=True, height_ratios=[1.0, 0.3]) +fig.suptitle("Wavelet Decomposition") + +im = plot_timefrequency(freqs, np.transpose(cwt_run[:, :].values), ax=axs[0]) +cbar = fig.colorbar(im, ax=axs[0], orientation="vertical") + +axs[1].plot(lfp_run.restrict(ex_run_ep)) +axs[1].set_ylabel("LFP (a.u.)") +axs[1].set_xlabel("Time (s)") +axs[1].margins(0) +``` + +*** +Filtering Theta +--------------- + +As expected, there is a strong 8Hz component during REM sleep. We can filter it using the function `nap.apply_bandpass_filter`. + +```{code-cell} ipython3 +theta_band = nap.apply_bandpass_filter(lfp_run, cutoff=(6.0, 12.0), fs=FS) +``` + +We can plot the original signal and the filtered signal. + +```{code-cell} ipython3 +plt.figure(constrained_layout=True, figsize=(12, 3)) +plt.plot(lfp_run.restrict(ex_run_ep), alpha=0.5) +plt.plot(theta_band.restrict(ex_run_ep)) +plt.xlabel("Time (s)") +plt.show() +``` + +*** +Computing phase +--------------- + +From the filtered signal, it is easy to get the phase using the Hilbert transform. Here we use scipy Hilbert method. + +```{code-cell} ipython3 +from scipy import signal + +phase = np.angle(signal.hilbert(theta_band)) # compute phase with hilbert transform +phase[phase < 0] += 2 * np.pi # wrap to [0,2pi] +theta_phase = nap.Tsd(t=theta_band.t, d=phase) +``` + +Let's plot the phase. + +```{code-cell} ipython3 +fig,axs = plt.subplots(2, 1, figsize=(12,4), constrained_layout=True, sharex=True, height_ratios=[2,1]) + +axs[0].plot(lfp_run.restrict(ex_run_ep), alpha=0.5, label="raw") +axs[0].plot(theta_band.restrict(ex_run_ep), label="filtered") +axs[0].set_ylabel("LFP (a.u.)") + +axs[1].plot(theta_phase.restrict(ex_run_ep), color='r') +axs[1].set_ylabel("Phase (rad)") +axs[1].set_xlabel("Time (s)") +``` + +*** +Finding Phase of Spikes +----------------------- +Now that we have the phase of our theta wavelet, and our spike times, we can find the phase firing preferences +of each of the units using the `compute_1d_tuning_curves` function. + +We will start by throwing away cells which do not have a high enough firing rate during our interval. + +```{code-cell} ipython3 +pyr_spikes = spikes[(spikes.rate > 1) & (spikes.rate < 10)] +``` + +compute place fields + +```{code-cell} ipython3 +from scipy.ndimage import gaussian_filter1d +place_fields = nap.compute_1d_tuning_curves(pyr_spikes, position, nb_bins=50) +# filter +place_fields[:] = gaussian_filter1d(place_fields.values, 1, axis=0) +``` + +```{code-cell} ipython3 +fig, axs = plt.subplots(6, 10, figsize=(30, 30)) +for i, (f, fields) in enumerate(place_fields.items()): + idx = np.unravel_index(i, axs.shape) + axs[idx].plot(fields) + axs[idx].set_title(f) +``` + +```{code-cell} ipython3 +plt.figure(constrained_layout=True, figsize = (12, 3)) +for i in range(3): + plt.subplot(1,3,i+1) + plt.plot(phase_modulation.iloc[:,i]) + plt.xlabel("Phase (rad)") + plt.ylabel("Firing rate (Hz)") +plt.show() +``` + +There is clearly a strong modulation for the third neuron. +Finally, we can use the function `value_from` to align each spikes to the corresponding phase position and overlay +it with the LFP. + +```{code-cell} ipython3 +unit = 177 +spike_phase = spikes[unit].value_from(theta_phase) +# spike_position = spikes[unit].value_from(position) +``` + +Let's plot it. + +```{code-cell} ipython3 +fig,axs = plt.subplots(2,1, figsize=(12,6), constrained_layout=True, sharex=True) +axs[0].plot(lfp_run.restrict(ex_run_ep)) +axs[0].plot(theta_band.restrict(ex_run_ep)) +axs[1].plot(theta_phase.restrict(ex_run_ep), alpha=0.5) +axs[1].plot(spike_phase.restrict(ex_run_ep), 'o') +ax = axs[1].twinx() +ax.plot(position.restrict(ex_run_ep)) +``` + +```{code-cell} ipython3 +spike_position = spikes[unit].value_from(position) +plt.subplots(figsize=(3,3)) +plt.plot(spike_phase, spike_position, 'o') +plt.xlabel("Phase (rad)") +plt.ylabel("Position (cm)") +``` + +```{code-cell} ipython3 +# hold out trial from place field computation +run_train = data["forward_ep"].set_diff(ex_run_ep) +position_train = data["position"].restrict(run_train) +place_fields = nap.compute_1d_tuning_curves(spikes, position_train, nb_bins=50) + +# filter place fields +tc = gaussian_filter1d(place_fields.values, 1, axis=0) +place_fields[:] = tc + +# use moving sum of spike counts +ct = spikes.restrict(ex_run_ep).count(0.01).convolve(np.ones(4)) +t = spikes.restrict(ex_run_ep).count(0.01).index +group = nap.TsdFrame(t=t, d=ct, columns=spikes.keys()) + +# decode +_, p = nap.decode_1d(place_fields, group, ex_run_ep, bin_size=0.04) + +# plot +plt.subplots(figsize=(12, 4), constrained_layout=True) +plt.pcolormesh(p.index, p.columns, np.transpose(p)) +plt.plot(position.restrict(ex_run_ep), color="r") +plt.xlabel("Time (s)") +plt.ylabel("Position (cm)") +plt.colorbar(label = "predicted probability") +``` + +:::{card} +Authors +^^^ +Kipp Freud (https://kippfreud.com/) + +Guillaume Viejo + +::: From f050302dcdc109b0f535966256528f821363bbba Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Thu, 9 Jan 2025 14:07:02 -0500 Subject: [PATCH 2/7] pooch registry for additional data sets --- src/workshop_utils/__init__.py | 2 +- src/workshop_utils/fetch.py | 61 ++++++++++++++++++++++++++++++++-- 2 files changed, 59 insertions(+), 4 deletions(-) diff --git a/src/workshop_utils/__init__.py b/src/workshop_utils/__init__.py index c3cba03..688941c 100644 --- a/src/workshop_utils/__init__.py +++ b/src/workshop_utils/__init__.py @@ -1,4 +1,4 @@ #!/usr/bin/env python3 -from .fetch import DOWNLOADABLE_FILES +from .fetch import DOWNLOADABLE_FILES, fetch_data from .plotting import * diff --git a/src/workshop_utils/fetch.py b/src/workshop_utils/fetch.py index a89aee8..8617df4 100644 --- a/src/workshop_utils/fetch.py +++ b/src/workshop_utils/fetch.py @@ -2,8 +2,62 @@ import click import nemos as nmo +import pooch -DOWNLOADABLE_FILES = ["allen_478498617.nwb", "Mouse32-140822.nwb", "Achilles_10252013.nwb"] +NEMOS_FILES = [ + "allen_478498617.nwb", + "Mouse32-140822.nwb", + "Achilles_10252013.nwb", +] + +DATA_REGISTRY = { + "Achilles_10252013_EEG.nwb": "a97a69d231e7e91c07e24890225f8fe4636bac054de50345551f32fc46b9efdd", +} + +DATA_URLS = { + "Achilles_10252013_EEG.nwb": "https://osf.io/2dfvp/download", +} + +DATA_ENV = "NEMOS_DATA_DIR" + +DOWNLOADABLE_FILES = NEMOS_FILES + list(DATA_REGISTRY.keys()) + + +def fetch_data(dataset_name, path=None): + """ + Fetch a data set for the neuroRSE workshop, including datasets not included in the NeMoS registry. + This essentially adds a second registry for the workshop, while still using the default download location as NeMoS. + + Parameters + ---------- + dataset_name : str + Name of the data set to fetch. + path : str, optional + Path to the directory where the data set should be stored. If not provided, the default NeMoS cache directory is used. + + Returns + ------- + str + Path to the downloaded data set. + """ + + if dataset_name in NEMOS_FILES: + return nmo.fetch.fetch_data(dataset_name, path=path) + + else: + if path is None: + path = pooch.os_cache("nemos") + + manager = pooch.create( + path=path, + base_url="", + urls=DATA_URLS, + registry=DATA_REGISTRY, + allow_updates="POOCH_ALLOW_UPDATES", + env=DATA_ENV, + ) + + return manager.fetch(dataset_name) @click.command() @@ -17,7 +71,8 @@ def main(): """ for f in DOWNLOADABLE_FILES: - nmo.fetch.fetch_data(f) + fetch_data(f) + -if __name__ == '__main__': +if __name__ == "__main__": main() From 56716402db1241929f4e848622fde245641fb023 Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Thu, 9 Jan 2025 14:24:28 -0500 Subject: [PATCH 3/7] add NEMOS_DATA_DIR on import --- src/workshop_utils/__init__.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/workshop_utils/__init__.py b/src/workshop_utils/__init__.py index 688941c..c9eab89 100644 --- a/src/workshop_utils/__init__.py +++ b/src/workshop_utils/__init__.py @@ -2,3 +2,9 @@ from .fetch import DOWNLOADABLE_FILES, fetch_data from .plotting import * + +import os +import pathlib + +repo_dir = pathlib.Path(__file__).parent.parent.parent +os.environ["NEMOS_DATA_DIR"] = os.environ.get("NEMOS_DATA_DIR", str(repo_dir / "data")) From 5ff7d8f9232506e10e6ea925df1a3b9513828710 Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Thu, 9 Jan 2025 16:45:10 -0500 Subject: [PATCH 4/7] Update phase_precession.md --- docs/source/full/day1/phase_precession.md | 177 ++++++++++++---------- 1 file changed, 98 insertions(+), 79 deletions(-) diff --git a/docs/source/full/day1/phase_precession.md b/docs/source/full/day1/phase_precession.md index a5ba37b..97296b0 100644 --- a/docs/source/full/day1/phase_precession.md +++ b/docs/source/full/day1/phase_precession.md @@ -37,60 +37,49 @@ import scipy import seaborn as sns import tqdm import pynapple as nap -import nemos as nmo +import workshop_utils custom_params = {"axes.spines.right": False, "axes.spines.top": False} sns.set_theme(style="ticks", palette="colorblind", font_scale=1.5, rc=custom_params) ``` *** -Downloading the data +Fetching the data -------------------- The data set we'll be looking at is from the manuscript [Diversity in neural firing dynamics supports both rigid and learned hippocampal sequences](https://www.science.org/doi/10.1126/science.aad1935). In this study, the authors collected electrophisiology data in rats across multiple sites in layer CA1 of hippocampus to extract the LFP alongside spiking activity of many simultaneous pyramidal units. In each recording session, data were collected while the rats explored a novel environment (a linear track), as well as during sleep before and after exploration. In our following analyses, we'll focus on the exploration period of a single rat and recording session. -First, we need to download the data and save it locally. Since the file size of a recording session can be large from the LFP saved for each recorded channel, we'll use a smaller file that contains the spiking activity and the LFP from a single, representative channel, which is hosted on [OSF](https://osf.io/2dfvp). This smaller file, like the original data, is saved as an [NWB](https://www.nwb.org) file. +The full dataset for this study can be accessed on [DANDI](https://dandiarchive.org/dandiset/000044/0.210812.1516). Since the file size of a recording session can be large from the LFP saved for each recorded channel, we'll use a smaller file that contains the spiking activity and the LFP from a single, representative channel, which is hosted on [OSF](https://osf.io/2dfvp). This smaller file, like the original data, is saved as an [NWB](https://www.nwb.org) file. -Full dataset: https://dandiarchive.org/dandiset/000044/0.210812.1516 - -(is there a simpler way of doing this? i.e. use nemos?) -- make workshop pooch, make sure its part of the pre-workshop download +If you ran the workshop setup script, you should have this file downloaded already. If not, the function we'll use to fetch it will download it for you. This function is called `fetch_data`, and can be imported from the `workshop_utils` module. ```{code-cell} ipython3 -path = "Achilles_10252013.nwb" -path = nmo.fetch.fetch_data(path) -# if path not in os.listdir("."): -# r = requests.get(f"https://osf.io/2dfvp/download", stream=True) -# block_size = 1024 * 1024 -# with open(path, "wb") as f: -# for data in tqdm.tqdm( -# r.iter_content(block_size), -# unit="MB", -# unit_scale=True, -# total=math.ceil(int(r.headers.get("content-length", 0)) // block_size), -# ): -# f.write(data) +from workshop_utils import fetch_data + +path = fetch_data("Achilles_10252013_EEG.nwb") ``` -*** -Loading the data ------------------- -With the file downloaded, we can use the pynapple function `load_file` to load in the data, which is able to handle NWB file types. +This function will give us the file path to where the data is stored. We can then use the pynapple function `load_file` to load in the data, which is able to handle the NWB file type. ```{code-cell} ipython3 data = nap.load_file(path) print(data) ``` -*** -Selecting a single run ------------------------------------ -for visualization, grab a single run down the linear track (selected in advance) +What this gives you is a dictionary of pynapple objects that have been inferred from the NWB file. This dictionary contains the following fields: +- `units`: a `TsGroup` with each units spike times as well as metadata about each unit (i.e. location, shank, and cell type). This dataset contains 137 units all in CA1. +- `rem`: an `IntervalSet` of REM sleep epochs, with 3 occuring before exploration and 1 occuring after. +- `nrem`: an `IntervalSet` of nREM sleep epochs, with 6 occuring before exploration and 5 occuring after. +- `forward_ep`: an `IntervalSet` containing each time window when the animal crossed the linear track in one direction. There are a total of 84 traversals in this session. +- `eeg`: a `TsdFrame` containing an LFP voltage traces for a single representative channel in CA1. +- `theta_phase`: a `Tsd` with the computed theta phase of the LFP used in the study. We will be computing this ourselves. +- `position`: a `Tsd` containing the linearized position -```{code-cell} ipython3 -ex_run_ep = data["forward_ep"][9] -``` ++++ -Restrict data to awake epochs: lfp, spikes, and position +*** +Filtering the data +------------------ +For the following exercises, we'll only focus on the exploration epochs contained in `forward_ep`. Therefore, when extracting the LFP, spikes, and position, we can use `restrict()` with the `forward_ep` IntervalSet to subselect the data. ```{code-cell} ipython3 lfp_run = data["eeg"][:,0].restrict(data["forward_ep"]) @@ -98,22 +87,34 @@ spikes = data["units"].restrict(data["forward_ep"]) position = data["position"].restrict(data["forward_ep"]) ``` +For visualization, we'll look at a single run down the linear track. For a good example, we'll start by looking at run 10 (python index 9). It is encouraged, however, to repeat these exercises on additional runs! + +```{code-cell} ipython3 +ex_run_ep = data["forward_ep"][9] +``` + *** -Plotting the LFP Activity ------------------------------------ -plot LFP and animal position during trial +Plotting the LFP and animal position +------------------------------------ +To get a sense of what the LFP looks like while the animal runs down the linear track, we can plot each variable, `lfp_run` and `position`, side-by-side. We'll want to further restrict each variable to our run of interest stored in `ex_run_ep`. + +```{code-cell} ipython3 +ex_lfp_run = lfp_run.restrict(ex_run_ep) +ex_position = position.restrict(ex_run_ep) +``` + +By default, plotting Tsd objects will use the time index on the x-axis. However, for a more interpretable time axis, we'll subtract the first time index from each variable's time indices and pass it as the first argument in matplotlib's `plot`. This will give the relative time elapsed on the current run. ```{code-cell} ipython3 fig, axs = plt.subplots(2, 1, constrained_layout=True, figsize=(10, 6), sharex=True) # plot LFP -axs[0].plot(lfp_run.restrict(ex_run_ep)) +axs[0].plot(ex_lfp_run.index - ex_lfp_run.index[0], ex_lfp_run) axs[0].set_title("Local Field Potential on Linear Track") axs[0].set_ylabel("LFP (a.u.)") -# axs[0].set_xlabel("time (s)") # plot animal's position -axs[1].plot(position.restrict(ex_run_ep)) +axs[1].plot(ex_position.index - ex_position.index[0], ex_position) axs[1].set_title("Animal Position on Linear Track") axs[1].set_ylabel("Position (cm)") # LOOK UP UNITS axs[1].set_xlabel("time (s)") @@ -122,78 +123,79 @@ axs[1].set_xlabel("time (s)") *** Getting the Wavelet Decomposition ----------------------------------- -As we would expect, it looks like we have a very strong theta oscillation within our data -- this is a common feature of REM sleep. Let's perform a wavelet decomposition, -as we did in the last tutorial, to see get a more informative breakdown of the -frequencies present in the data. +As we would expect, there is a strong theta oscillation dominating the LFP while the animal runs down the track. To illustrate this further, we'll perform a wavelet decomposition on the LFP trace during this run. -We must define the frequency set that we'd like to use for our decomposition. +DEFINE WAVELET DECOMPOSITION + +We must define the frequency set that we'd like to use for our decomposition. We can do this with the numpy function `np.geomspace`, which returns numbers evenly spaced on a log scale. We pass the lower frequency, the upper frequency, and number of samples as positional arguments. ```{code-cell} ipython3 -freqs = np.geomspace(5, 200, 25) +# 25 log-spaced samples between 5Hz and 200Hz +freqs = np.geomspace(5, 200, 100) ``` -We compute the wavelet transform on our LFP data (only during the example interval). +We can now compute the wavelet transform on our LFP data during the example run using the pynapple function `nap.compute_wavelet_trasform`, which takes the time series and array of frequencies as positional arguments. Optionally, we can pass the keyword argument `fs` to provide the the sampling frequency, which is known to be 1250Hz from the study methods. -double check: FS tracked to https://www.jneurosci.org/content/28/26/6731 methods +double check: FS back tracked to https://www.jneurosci.org/content/28/26/6731 methods ```{code-cell} ipython3 -FS = 1250 # We know from the methods of the paper -cwt_run = nap.compute_wavelet_transform(lfp_run.restrict(ex_run_ep), fs=FS, freqs=freqs) +sample_freq = 1250 # We know from the methods of the paper +cwt_run = nap.compute_wavelet_transform(lfp_run.restrict(ex_run_ep), freqs, fs=sample_freq) ``` -*** -Now let's plot the calculated wavelet scalogram. +If `fs` is not provided, it can be inferred from the time series `rate` attribute, which matches what was pulled from the methods + +```{code-cell} ipython3 +print(ex_lfp_run.rate) +``` + +We can visualize the results by plotting a heat map of the calculated wavelet scalogram. ```{code-cell} ipython3 -# Define wavelet decomposition plotting function -def plot_timefrequency(freqs, powers, ax=None): - im = ax.imshow(np.abs(powers), aspect="auto") - ax.invert_yaxis() - ax.set_xlabel("Time (s)") - ax.set_ylabel("Frequency (Hz)") - ax.get_xaxis().set_visible(False) - ax.set(yticks=np.arange(len(freqs))[::2], yticklabels=np.rint(freqs[::2])) - ax.grid(False) - return im - -fig, axs = plt.subplots(2, 1, figsize=(10,6), constrained_layout=True, height_ratios=[1.0, 0.3]) +fig, axs = plt.subplots(2, 1, figsize=(10,6), constrained_layout=True, height_ratios=[1.0, 0.3], sharex=True) fig.suptitle("Wavelet Decomposition") -im = plot_timefrequency(freqs, np.transpose(cwt_run[:, :].values), ax=axs[0]) -cbar = fig.colorbar(im, ax=axs[0], orientation="vertical") +t = ex_lfp_run.index - ex_lfp_run.index[0] +power = np.abs(cwt_run.values) +cax = axs[0].pcolormesh(t, freqs, power.T) +axs[0].set(ylabel="Frequency (Hz)", yscale='log', yticks=freqs[::10], yticklabels=np.rint(freqs[::10])); +axs[0].minorticks_off() +fig.colorbar(cax,label="Power") -axs[1].plot(lfp_run.restrict(ex_run_ep)) -axs[1].set_ylabel("LFP (a.u.)") -axs[1].set_xlabel("Time (s)") +axs[1].plot(t, ex_lfp_run) +axs[1].set(ylabel="LFP (a.u.)", xlabel="Time(s)") axs[1].margins(0) ``` *** -Filtering Theta ---------------- +Filtering for theta +------------------- +We can extract the theta oscillation by applying a bandpass filter on the raw LFP. To do this, we use the pynapple function `nap.apply_bandpass_filter`, which takes the time series as the first argument and the frequency cutoffs as the second argument. Similarly to `nap.compute_wavelet_transorm`, we can optinally pass the sampling frequency keyword argument `fs`. -As expected, there is a strong 8Hz component during REM sleep. We can filter it using the function `nap.apply_bandpass_filter`. +Conveniently, this function will recognize and handle splits in the subsampled data (i.e. applying the filtering separately to discontinuous epochs), so we can pass the LFP for all the runs together. ```{code-cell} ipython3 -theta_band = nap.apply_bandpass_filter(lfp_run, cutoff=(6.0, 12.0), fs=FS) +theta_band = nap.apply_bandpass_filter(lfp_run, (6.0, 12.0), fs=sample_freq) ``` -We can plot the original signal and the filtered signal. +We can visualize the output by plotting the filtered signal with the original signal. ```{code-cell} ipython3 plt.figure(constrained_layout=True, figsize=(12, 3)) -plt.plot(lfp_run.restrict(ex_run_ep), alpha=0.5) -plt.plot(theta_band.restrict(ex_run_ep)) -plt.xlabel("Time (s)") -plt.show() +plt.plot(t, ex_lfp_run, alpha=0.5, label="raw") +plt.plot(t, theta_band.restrict(ex_run_ep), label="filtered") +plt.ylabel("Time (s)") +plt.xlabel("LFP (a.u.)") +plt.title("Bandpass filter for theta oscillations (6-12 Hz)") +plt.legend(); ``` *** Computing phase --------------- +In order to examine phase precession in place cells, we need to extract the phase of theta from the filtered signal. We can do this by taking the angle of the [Hilbert transform](https://en.wikipedia.org/wiki/Hilbert_transform). -From the filtered signal, it is easy to get the phase using the Hilbert transform. Here we use scipy Hilbert method. +The `signal` module of `scipy` includes a function to perform the Hilbert transform, after which we can use the numpy function `np.angle` to extract the angle. ```{code-cell} ipython3 from scipy import signal @@ -211,15 +213,32 @@ fig,axs = plt.subplots(2, 1, figsize=(12,4), constrained_layout=True, sharex=Tru axs[0].plot(lfp_run.restrict(ex_run_ep), alpha=0.5, label="raw") axs[0].plot(theta_band.restrict(ex_run_ep), label="filtered") axs[0].set_ylabel("LFP (a.u.)") +axs[0].legend() axs[1].plot(theta_phase.restrict(ex_run_ep), color='r') axs[1].set_ylabel("Phase (rad)") axs[1].set_xlabel("Time (s)") ``` +```{code-cell} ipython3 +fig,ax = plt.subplots(figsize=(12,2), constrained_layout=True) #, sharex=True, height_ratios=[2,1]) + +ax.plot(t, theta_phase.restrict(ex_run_ep), color='r', label="phase") +ax.set_ylabel("Phase (rad)") +ax.set_xlabel("Time (s)") +ax = ax.twinx() +ax.plot(t, theta_band.restrict(ex_run_ep), alpha=0.5, label="filtered LFP") +ax.set_ylabel("LFP (a.u.)") +fig.legend() +``` + +cycle "resets" at peaks + ++++ + *** -Finding Phase of Spikes ------------------------ +Identifying place-selective cells +--------------------------------- Now that we have the phase of our theta wavelet, and our spike times, we can find the phase firing preferences of each of the units using the `compute_1d_tuning_curves` function. From 9a7165fdd0c680f0ddf92cf820c34f70ddd175d0 Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Fri, 10 Jan 2025 15:59:50 -0500 Subject: [PATCH 5/7] animated 1d convolution plot --- src/workshop_utils/plotting.py | 225 ++++++++++++++++++++++++++++++++- 1 file changed, 224 insertions(+), 1 deletion(-) diff --git a/src/workshop_utils/plotting.py b/src/workshop_utils/plotting.py index b8713da..203e7c9 100644 --- a/src/workshop_utils/plotting.py +++ b/src/workshop_utils/plotting.py @@ -6,8 +6,11 @@ import numpy as np from typing import Union from numpy.typing import NDArray +from matplotlib.animation import FuncAnimation + + +__all__ = ["plot_features", "animate_1d_convolution"] -__all__ = ["plot_features"] def plot_features( input_feature: Union[nap.Tsd, nap.TsdFrame, nap.TsdTensor, NDArray], @@ -56,3 +59,223 @@ def plot_features( plt.tight_layout() return fig + + +import pynapple as nap +import pandas as pd +import matplotlib.pyplot as plt +import numpy as np +from typing import Union +from numpy.typing import NDArray +from matplotlib.animation import FuncAnimation +from matplotlib.animation import FuncAnimation + + +class Plot1DConvolution: + """ + Class to plot an animation of convolving some 1D kernel with some Tsd array. + + Parameters + ---------- + tsd : + The Tsd object to convolve with the kernel. + kernel : + The 1D kernel to convolve with the array. + index : + The time index. Taken from the Tsd object if not provided. + start : + The index along the x-axis to start the animation. Defaults to the start of the window. + interval : + The interval between frames in milliseconds. + figsize : + The figure size. + ylim : + The y-axis limits. + xlabel : + The x-axis label. + ylabel : + The y-axis label. + tsd_label : + The legend label for the Tsd array + kernel_label : + The legend label for the kernel + conv_label : + The legend label for the convolution output + split_kernel_yaxis : + Whether or not to have a separate y-axis (i.e. use twinx()) for plotting the kernel. Useful if the kernel is magnitudes smaller/larger than the Tsd. + """ + + def __init__( + self, + tsd: nap.Tsd, + kernel: NDArray, + index: NDArray = None, + start: int = 0, + interval: float = 100, + figsize: tuple = (10, 3), + ylim: float = None, + xlabel: str = "Time (s)", + ylabel: str = "Count", + tsd_label: str = "original array", + kernel_label: str = "kernel", + conv_label: str = "convolution", + split_kernel_yaxis: bool = False, + ): + self.tsd = tsd + self.kernel = kernel + if index is None: + self.index = tsd.index.values + else: + self.index = index + self.start = start + self.conv = tsd.convolve(kernel) + self.conv_viz = np.zeros_like(tsd) + self.frames = len(tsd) - start + self.interval = interval + if ylim is None: + if split_kernel_yaxis: + ymin = np.min((self.tsd.min(), self.conv.min())) + ymax = np.max((self.tsd.max(), self.conv.max())) + else: + ymin = np.min((self.tsd.min(), self.conv.min(), self.kernel.min())) + ymax = np.max((self.tsd.max(), self.conv.max(), self.kernel.max())) + ylim = (ymin, ymax) + self.ylim = ylim + self.xlabel = xlabel + self.ylabel = ylabel + self.tsd_label = tsd_label + self.kernel_label = kernel_label + self.conv_label = conv_label + self.split_kernel_yaxis = split_kernel_yaxis + ( + self.fig, + self.kernel_line, + self.conv_line, + self.conv_area, + self.top_idx_line, + self.bottom_idx_line, + ) = self.setup(figsize) + + def setup(self, figsize): + """ + Initialization of the plot. + """ + # initial placement of kernel + kernel_full = np.zeros_like(self.tsd) + kidx, kmid = self.kernel_bounds(0) + if np.any(kidx): + kernel_full[kidx] = self.kernel[: len(kidx)] + + fig, axs = plt.subplots(2, 1, figsize=figsize, sharex=True, sharey=True) + + ### top plot ### + ax = axs[0] + # this is fixed + ax.plot(self.index, self.tsd, label="original array") + + # initial visible convolution output and top center line + if kmid >= 0: + self.conv_viz[: kmid + 1] = self.conv[: kmid + 1] + cx = self.index[kmid] + else: + cx = self.index[0] + top_idx_line = ax.plot((cx, cx), self.ylim, "--", color="black", alpha=0.5)[0] + + # initial filled area + conv_area = ax.fill_between( + self.index, + np.zeros_like(self.tsd), + self.tsd * kernel_full.values, + alpha=0.5, + color="green", + ) + + # initial kernel plot + if self.split_kernel_yaxis: + ax = ax.twinx() + ax.set_ylabel(self.kernel_label) + ax.set_ylim((kernel_full.min(), kernel_full.max())) + kernel_line = ax.plot( + self.index, kernel_full, color="orange", label=self.kernel_label + )[0] + + ### bottom plot ### + ax = axs[1] + # initial convolution output and bottom plot center line + conv_line = ax.plot( + self.index, self.conv_viz, color="green", label=self.conv_label + )[0] + bottom_idx_line = ax.plot((cx, cx), self.ylim, "--", color="black", alpha=0.5)[ + 0 + ] + + ax.set_ylim(self.ylim) + + fig.legend() + fig.supxlabel(self.xlabel) + fig.supylabel(self.ylabel) + plt.tight_layout() + + return fig, kernel_line, conv_line, conv_area, top_idx_line, bottom_idx_line + + def update(self, frame): + if frame > 0: + # place kernel at shifted location based on frame number + kernel_full = np.zeros_like(self.tsd) + kidx, kmid = self.kernel_bounds(frame) + kernel_full[kidx] = self.kernel[: len(kidx)] + self.kernel_line.set_ydata(kernel_full) + + # update visible convolution output + if kmid >= 0: + self.conv_viz[kmid] = self.conv[kmid] + self.conv_line.set_ydata(self.conv_viz) + self.top_idx_line.set_xdata((self.index[kmid], self.index[kmid])) + self.bottom_idx_line.set_xdata((self.index[kmid], self.index[kmid])) + + # update filled area + self.conv_area.set_data( + self.index, np.zeros_like(self.tsd), self.tsd * kernel_full.values + ) + + def run(self): + anim = FuncAnimation( + self.fig, self.update, self.frames, interval=self.interval, repeat=True + ) + plt.close(self.fig) + return anim + + def kernel_bounds(self, frame): + # kernel bounds set to the left of the frame index and start location + kmin = frame + self.start - len(self.kernel) + kmax = frame + self.start + + # kernel indices no less than 0 and no more than the length of the Tsd + kidx = np.arange(np.max((kmin, 0)), np.min((kmax, len(self.tsd)))) + + # convolution output w.r.t. the midpoint of where the kernel is placed + kmid = kmin + np.floor(len(self.kernel) / 2).astype(int) + + return kidx, kmid + + +def animate_1d_convolution(tsd: nap.Tsd, kernel: NDArray, **kwargs): + """ + Animate the convolution of a 1D kernel with some Tsd array. + + Parameters + ---------- + tsd : nap.Tsd + The Tsd object to be convolved. + kernel : np.ndarray + The 1D kernel to convolve with the array. + **kwargs + Additional keyword arguments to pass to Plot1DConvolution. + + Returns + ------- + matplotlib.animation.FuncAnimation + The animation object. + """ + anim = Plot1DConvolution(tsd, kernel, **kwargs) + return anim.run() From e315915343ce5affdb8eb972c45477cd37a9bef5 Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Mon, 13 Jan 2025 13:00:38 -0500 Subject: [PATCH 6/7] finished first draft of phase precession notebook --- docs/source/full/day1/phase_precession.md | 284 +++++++++++++++++----- src/workshop_utils/plotting.py | 2 +- 2 files changed, 219 insertions(+), 67 deletions(-) diff --git a/docs/source/full/day1/phase_precession.md b/docs/source/full/day1/phase_precession.md index 97296b0..3246611 100644 --- a/docs/source/full/day1/phase_precession.md +++ b/docs/source/full/day1/phase_precession.md @@ -17,6 +17,18 @@ Phase precession and hippocampal sequences In this tutorial we will learn how to apply two methods included in pynapple: filtering and decoding. We'll apply these methods to demonstrate and visualize some well-known physiological properties of hippocampal activity, specifically phase presession of place cells and sequential coordination of place cell activity during theta oscillations. +Pynapple functions used: +- load_file +- restrict +- compute_wavelet_transform +- apply_bandpass_filter +- compute_1d_tuning_curves +- value_from +- set_diff +- count +- convolve +- decode_1d + Background ---------- - hippocampus (rat) @@ -39,8 +51,9 @@ import tqdm import pynapple as nap import workshop_utils -custom_params = {"axes.spines.right": False, "axes.spines.top": False} -sns.set_theme(style="ticks", palette="colorblind", font_scale=1.5, rc=custom_params) +# necessary for animation +import nemos as nmo +plt.style.use(nmo.styles.plot_style) ``` *** @@ -53,9 +66,7 @@ The full dataset for this study can be accessed on [DANDI](https://dandiarchive. If you ran the workshop setup script, you should have this file downloaded already. If not, the function we'll use to fetch it will download it for you. This function is called `fetch_data`, and can be imported from the `workshop_utils` module. ```{code-cell} ipython3 -from workshop_utils import fetch_data - -path = fetch_data("Achilles_10252013_EEG.nwb") +path = workshop_utils.fetch_data("Achilles_10252013_EEG.nwb") ``` This function will give us the file path to where the data is stored. We can then use the pynapple function `load_file` to load in the data, which is able to handle the NWB file type. @@ -172,7 +183,7 @@ Filtering for theta ------------------- We can extract the theta oscillation by applying a bandpass filter on the raw LFP. To do this, we use the pynapple function `nap.apply_bandpass_filter`, which takes the time series as the first argument and the frequency cutoffs as the second argument. Similarly to `nap.compute_wavelet_transorm`, we can optinally pass the sampling frequency keyword argument `fs`. -Conveniently, this function will recognize and handle splits in the subsampled data (i.e. applying the filtering separately to discontinuous epochs), so we can pass the LFP for all the runs together. +Conveniently, this function will recognize and handle splits in the epoched data (i.e. applying the filtering separately to discontinuous epochs), so we can pass the LFP for all the runs together. ```{code-cell} ipython3 theta_band = nap.apply_bandpass_filter(lfp_run, (6.0, 12.0), fs=sample_freq) @@ -205,20 +216,7 @@ phase[phase < 0] += 2 * np.pi # wrap to [0,2pi] theta_phase = nap.Tsd(t=theta_band.t, d=phase) ``` -Let's plot the phase. - -```{code-cell} ipython3 -fig,axs = plt.subplots(2, 1, figsize=(12,4), constrained_layout=True, sharex=True, height_ratios=[2,1]) - -axs[0].plot(lfp_run.restrict(ex_run_ep), alpha=0.5, label="raw") -axs[0].plot(theta_band.restrict(ex_run_ep), label="filtered") -axs[0].set_ylabel("LFP (a.u.)") -axs[0].legend() - -axs[1].plot(theta_phase.restrict(ex_run_ep), color='r') -axs[1].set_ylabel("Phase (rad)") -axs[1].set_xlabel("Time (s)") -``` +Let's plot the phase on top of the filtered LFP signal. ```{code-cell} ipython3 fig,ax = plt.subplots(figsize=(12,2), constrained_layout=True) #, sharex=True, height_ratios=[2,1]) @@ -232,106 +230,260 @@ ax.set_ylabel("LFP (a.u.)") fig.legend() ``` -cycle "resets" at peaks +We can see that cycle "resets" (i.e. goes from $2\pi$ to $0$) at peaks of the theta oscillation. +++ *** Identifying place-selective cells --------------------------------- -Now that we have the phase of our theta wavelet, and our spike times, we can find the phase firing preferences -of each of the units using the `compute_1d_tuning_curves` function. +In order to identify phase precession in individual units, we need to know their place selectivity. We can find place firing preferences +of each unit by using the `compute_1d_tuning_curves` function. -We will start by throwing away cells which do not have a high enough firing rate during our interval. +We'll start by narrowing down our cells to lower-firing, putative pyramidal cells, or units that fire between 1 and 10 Hz on average, to reduce to units that are likely selective to a single location. These units will give us the clearest examples of phase precession. ```{code-cell} ipython3 pyr_spikes = spikes[(spikes.rate > 1) & (spikes.rate < 10)] ``` -compute place fields +Using these units and the position data, we can compute their place fields by using `nap.compute_1d_tuning_curves`. The first argument will be the TsGroup of our spikes, the second argument the Tsd feature of position, and the third argument the number of evenly-spaced bins in which to split the feature for the tuning curves. This function will return a `pandas.DataFrame`, where the index is the corresponding feature value, and the column is the unit label. ```{code-cell} ipython3 from scipy.ndimage import gaussian_filter1d -place_fields = nap.compute_1d_tuning_curves(pyr_spikes, position, nb_bins=50) -# filter +place_fields = nap.compute_1d_tuning_curves(pyr_spikes, position, 50) +``` + +It is customary to apply a gaussian smoothing filter to place fields, which we can do by applying the `scipy` function `gaussian_filter1d`. + +```{code-cell} ipython3 +# apply a smoothing filter place_fields[:] = gaussian_filter1d(place_fields.values, 1, axis=0) ``` +We can use a subplot array to visualize the place fields of many units simultaneously. Let's do this for the first 50 units. + ```{code-cell} ipython3 -fig, axs = plt.subplots(6, 10, figsize=(30, 30)) -for i, (f, fields) in enumerate(place_fields.items()): +fig, axs = plt.subplots(10, 5, figsize=(12, 15), sharex=True, constrained_layout=True) +for i, (f, fields) in enumerate(place_fields.iloc[:,:50].items()): idx = np.unravel_index(i, axs.shape) axs[idx].plot(fields) axs[idx].set_title(f) -``` -```{code-cell} ipython3 -plt.figure(constrained_layout=True, figsize = (12, 3)) -for i in range(3): - plt.subplot(1,3,i+1) - plt.plot(phase_modulation.iloc[:,i]) - plt.xlabel("Phase (rad)") - plt.ylabel("Firing rate (Hz)") -plt.show() +fig.supylabel("Firing rate (Hz)") +fig.supxlabel("Position (cm)") ``` -There is clearly a strong modulation for the third neuron. -Finally, we can use the function `value_from` to align each spikes to the corresponding phase position and overlay -it with the LFP. +We can see spatial selectivity in each of the units; across the population, we have firing fields tiling the entire linear track. + +To look at phase precession, we'll zoom in on unit 177. This unit has a single strong firing field in the middle of the track, which will be conducive for visualizing phase precession. + ++++ + +*** +Computing phase precession within a single unit +----------------------------------------------- +As a first visualization of phase precession, we'll look at a single traversal of the linear track. We'll want the corresponding phase of theta at which the unit fires as the animal is running down the track, which we can compute using a pynapple object's method `value_from`. For our spiking data, this will find the phase value closest in time to each spike. ```{code-cell} ipython3 unit = 177 spike_phase = spikes[unit].value_from(theta_phase) -# spike_position = spikes[unit].value_from(position) ``` -Let's plot it. +To see the results, let's plot the theta phase, the spike phase, and the animal's position across the run, as well as the unit's place field as a reminder of it's spatial preference. (Since the relationship between a single run's time and position is nearly linear, the x-axis of position and the x-axis of time will be well-aligned.) ```{code-cell} ipython3 -fig,axs = plt.subplots(2,1, figsize=(12,6), constrained_layout=True, sharex=True) -axs[0].plot(lfp_run.restrict(ex_run_ep)) -axs[0].plot(theta_band.restrict(ex_run_ep)) -axs[1].plot(theta_phase.restrict(ex_run_ep), alpha=0.5) -axs[1].plot(spike_phase.restrict(ex_run_ep), 'o') -ax = axs[1].twinx() -ax.plot(position.restrict(ex_run_ep)) +fig,axs = plt.subplots(2,1 , figsize=(10,4), constrained_layout=True) +axs[0].plot(theta_phase.restrict(ex_run_ep), alpha=0.5, label="theta phase") +axs[0].plot(spike_phase.restrict(ex_run_ep), 'o', label="spike phase") +axs[0].set_ylabel("Phase (rad)") +axs[0].set_xlabel("Time (s)") +axs[0].set_title("Unit 177 spike phase and animal position") +ax = axs[0].twinx() +ax.plot(ex_position, color="green", label="position") +ax.set_ylabel("Position (cm)") +fig.legend() + +axs[1].plot(place_fields[unit]) +axs[1].set_ylabel("Firing rate (Hz)") +axs[1].set_xlabel("Position (cm)") +axs[1].set_title("Unit 177 place field") ``` +As expected, unit 177 will preferentially spike (orange dots) as the animal runs through the middle of the track. Additionally, you should be able to see a negative trend in the spike phase as the animal move's further along the track. This phemomena is what is called phase precession: as an animal runs through the place field of a single unit, that unit will spike at *late* phases of theta (higher radians) in *earlyr* positions in the field, and fire at *early* phases of theta (lower radians) in *late* positions in the field. + +We can observe this phenomena on average across all runs by relating the spike phase to the spike position. Similar to before, we'll use the pynapple object method `value_from` to additionally find the animal position closest in time to each spike. + ```{code-cell} ipython3 spike_position = spikes[unit].value_from(position) -plt.subplots(figsize=(3,3)) -plt.plot(spike_phase, spike_position, 'o') -plt.xlabel("Phase (rad)") -plt.ylabel("Position (cm)") ``` +Now we can plot the spike phase against the spike position in a scatter plot. + +```{code-cell} ipython3 +plt.subplots(figsize=(4,3)) +plt.plot(spike_position, spike_phase, 'o') +plt.ylabel("Phase (rad)") +plt.xlabel("Position (cm)") +``` + +Similar to what we saw in a single run, there is a negative relationship between theta phase and field position, characteristic of phase precession. + ++++ + +*** +Decoding position from spiking activity +--------------------------------------- +Next we'll do a popular analysis in the rat hippocampal sphere: Bayesian decoding. This analysis is an elegent application of Bayes' rule in predicting the animal's location (or other behavioral variables) from neural activity at some point in time. + +### Background +Recall Bayes' rule, written here in terms of our relevant variables: + +$$P(position|spikes) = \frac{P(position)P(spikes|position)}{P(spikes)}$$ + +Our goal is to compute the unknown posterior $P(position|spikes)$ given known prior $P(position)$ and known likelihood $P(spikes|position)$. + +$P(position)$, also known as the *occupancy*, is the probability that the animal is occupying some position. This can be computed exactly by the proportion of the total time spent at each position, but in many cases it is sufficient to estimate the occupancy as a uniform distribution, i.e. it is equally likely for the animal to occupy any location. + +The next term, $P(spikes|position)$, which is the probability of seeing some sequence of spikes across all neurons at some position. Computing this relys on the following assumptions: +1. Neurons fire according to a Poisson process (i.e. their spiking activity follows a Poisson distribution) +2. Neurons fire independently from one another. + +While neither of these assumptions are strictly true, they are generally reasonable for pyramidal cells in hippocampus and allow us to simplify our computation of $P(spikes|position)$ + +The first assumption gives us an equation for $P(spikes|position)$ for a single neuron, which we'll call $P(spikes_i|position)$ to differentiate it from $P(spikes|position) = P(spikes_1,spikes_2,...,spikes_i,...,spikes_N|position) $, or the total probability across all $N$ neurons. The equation we get is that of the Poisson distribution: +$$ +P(spikes_i|position) = \frac{(\tau f_i(position))^n e^{-\tau f_i(position)}}{n!} +$$ +where $f_i(position)$ is the firing rate of the neuron at position $(position)$ (i.e. the tuning curve), $\tau$ is the width of the time window over which we're computing the probability, and $n$ is the total number of times the neuron spiked in the time window of interest. + +The second assumptions allows us to simply combine the probabilities of individual neurons. Recall the product rule for independent events: $P(A,B) = P(A)P(B)$ if $A$ and $B$ are independent. Treating neurons as independent, then, gives us the following: +$$ +P(spikes|position) = \prod_i P(spikes_i|position) +$$ + +The final term, $P(spikes)$, is inferred indirectly using the law of total probability: + +$$P(spikes) = \sum_{position}P(position,spikes) = \sum_{position}P(position)P(spikes|position)$$ + +Another way of putting it is $P(spikes)$ is the normalization factor such that $\sum_{position} P(position|spikes) = 1$, which is achived by dividing the numerator by its sum. + +If this method looks daunting, we have some good news: pynapple has it implemented already in the function `nap.decode_1d` for decoding a single dimension (or `nap.decode_2d` for two dimensions). All we'll need are the spikes, the tuning curves, and the width of the time window $\tau$. + +### Cross-validation + +Generally this method is cross-validated, which means you train the model on one set of data and test the model on a different, held-out data set. For Bayesian decoding, the "model" refers to the model *likelihood*, which is computed from the tuning curves. + +We want to decode the example run we've been using throughout this exercise; therefore, our training set should omit this run before computing the tuning curves. We can do this by using the IntervalSet method `set_diff`, to take out the example run epoch from all run epochs. + ```{code-cell} ipython3 # hold out trial from place field computation run_train = data["forward_ep"].set_diff(ex_run_ep) +``` + +Next, we'll restrict our data to these training epochs and re-compute the place fields using `nap.compute_1d_tuning_curves`. Similar to before, we'll applying a Gaussian smoothing filter to the place fields, which will smooth our decoding results down the line. + +```{code-cell} ipython3 position_train = data["position"].restrict(run_train) place_fields = nap.compute_1d_tuning_curves(spikes, position_train, nb_bins=50) +place_fields[:] = gaussian_filter1d(place_fields.values, 1, axis=0) +``` + +### Run decoder + +This is the minumum needed to run the `nap.decode_1d` function. The first input will be our tuning curves (place fields), the second input is the `TsGroup` of spike times corresponding to units in the tuning curve DataFrame (this can also be a `TsdFrame` of spike counts), the third input is the epoch we want to decode, and the fourth input is the bin size, or the time resolution $\tau$ at which to decode. + +```{code-cell} ipython3 +decoded_position, decoded_prob = nap.decode_1d(place_fields, spikes, ex_run_ep, bin_size=0.2) +``` + +The first output is the inferred position from the decoder, and the second output is the posterior distribution, giving the probability distribution of position given the spiking activity. (The decoded position is simply the position at which the probability is greatest.) + +Let's plot the posterior distribution and overlay the animal's true position. + +```{code-cell} ipython3 +plt.subplots(figsize=(10, 4), constrained_layout=True) +plt.pcolormesh(decoded_position.index,place_fields.index,np.transpose(decoded_prob)) +plt.plot(decoded_position, color="green") +plt.plot(ex_position, color="red") +``` + +The decoder does a reasonable job at following the animals true position, but gets worse at shorter bin sizes. + +```{code-cell} ipython3 +decoded_position, decoded_prob = nap.decode_1d(place_fields, spikes, ex_run_ep, bin_size=0.05) +plt.subplots(figsize=(10, 4), constrained_layout=True) +plt.pcolormesh(decoded_position.index,place_fields.index,np.transpose(decoded_prob)) +plt.plot(decoded_position, color="green") +plt.plot(ex_position, color="red") +``` -# filter place fields -tc = gaussian_filter1d(place_fields.values, 1, axis=0) -place_fields[:] = tc +### Smooth spike counts for decoder -# use moving sum of spike counts -ct = spikes.restrict(ex_run_ep).count(0.01).convolve(np.ones(4)) -t = spikes.restrict(ex_run_ep).count(0.01).index -group = nap.TsdFrame(t=t, d=ct, columns=spikes.keys()) +One way to improve our estimation at shorter bin sizes is to instead use *sliding windows* to bin our data. This allows us to combine the accuracy of a larger bin size with the resolution of a smaller bin size by essentially smoothing the spike counts. This is a feature that will be added in a future version of pynapple, but we can still apply it ourselves by providing the spike counts directly to `nap.decode_1d` as a `TsdFrame`. -# decode -_, p = nap.decode_1d(place_fields, group, ex_run_ep, bin_size=0.04) +Let's say we want a sliding window of $200 ms$ that shifts by $50 ms$. We can compute this efficiently by first binning at the smaller $50 s$ bin size, which we can do by applying the pynapple object method `count`. -# plot +```{code-cell} ipython3 +counts = spikes.restrict(ex_run_ep).count(0.05) +``` + +Next, we apply a moving sum on each set of $200 ms / 50 ms = 4$ adjacent bins to "smooth" each count into $200 ms$ bins. This is the same as convolving the counts with a length 4 kernel of ones. + +```{code-cell} ipython3 +smth_counts = counts.convolve(np.ones(4)) +``` + +To see this in action, we've provided an animation to visualize the convolution on a single unit. In the top figure, we'll see the original counts that have been binned in $50ms$ windows as well as the kernel representing the moving sum as it slides acorss the trial. When the kernel meets the binned counts, the convolution is equal to the integral of $kernel * counts$, or the sum of the shaded green area. The result of the convolution is in the bottom plot, a smoothed version of the counts in the top plot. + +```{code-cell} ipython3 +workshop_utils.animate_1d_convolution(counts.loc[177], np.ones(4), tsd_label="original counts", kernel_label="moving sum", conv_label="convolved counts") +``` + +Let's use `nap.decode_1d` again, but now with our smoothed counts in place of the raw spike times. Note that the bin size we'll want to provide the the larger bin size, $200ms$, since this is the true width of each bin. + +```{code-cell} ipython3 +smth_decoded_position, smth_decoded_prob = nap.decode_1d(place_fields, smth_counts, ex_run_ep, bin_size=0.2) +``` + +Let's plot the results. + +```{code-cell} ipython3 plt.subplots(figsize=(12, 4), constrained_layout=True) -plt.pcolormesh(p.index, p.columns, np.transpose(p)) -plt.plot(position.restrict(ex_run_ep), color="r") +plt.pcolormesh(smth_decoded_prob.index, smth_decoded_prob.columns, np.transpose(smth_decoded_prob)) +plt.plot(smth_decoded_position, color="green") +plt.plot(ex_position, color="r") plt.xlabel("Time (s)") plt.ylabel("Position (cm)") plt.colorbar(label = "predicted probability") ``` +We can see a much smoother estimate of position, as well as a smoother posterior probability of position. + +### Bonus: theta sequences +Units phase precessing together creates fast, spatial sequences around the animal's true position. We can reveal this by decoding at an even shorter time scale, which will appear as errors in the decoder. + +```{code-cell} ipython3 +counts = spikes.restrict(ex_run_ep).count(0.01) +smth_counts = counts.convolve(np.ones(4)) +smth_decoded_position, smth_decoded_prob = nap.decode_1d(place_fields, smth_counts, ex_run_ep, bin_size=0.04) + +fig, axs = plt.subplots(2, 1, figsize=(12, 4), constrained_layout=True, height_ratios=[3,1], sharex=True) +c = axs[0].pcolormesh(smth_decoded_prob.index, smth_decoded_prob.columns, np.transpose(smth_decoded_prob)) +axs[0].plot(smth_decoded_position, color="green") +axs[0].plot(ex_position, color="r") +# axs[0].set_xlabel("Time (s)") +axs[0].set_ylabel("Position (cm)") +fig.colorbar(c, label = "predicted probability") + +axs[1].plot(ex_lfp_run) +axs[1].plot(theta_band.restrict(ex_run_ep)) +axs[1].set_ylabel("LFP (a.u.)") + +fig.supxlabel("Time (s)") +``` + +The estimated position oscillates with cycles of theta, which are referred to as "theta sequences". Fully understanding this organization of place cells and its role in learning, memory, and planning is an active topic of research in Neuroscience! + :::{card} Authors ^^^ diff --git a/src/workshop_utils/plotting.py b/src/workshop_utils/plotting.py index 203e7c9..bfdb40c 100644 --- a/src/workshop_utils/plotting.py +++ b/src/workshop_utils/plotting.py @@ -171,7 +171,7 @@ def setup(self, figsize): ### top plot ### ax = axs[0] # this is fixed - ax.plot(self.index, self.tsd, label="original array") + ax.plot(self.index, self.tsd, label=self.tsd_label) # initial visible convolution output and top center line if kmid >= 0: From 8cdc0e5168a5249b14b1eb36c6ab462c4213e36e Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Mon, 13 Jan 2025 13:38:21 -0500 Subject: [PATCH 7/7] Update phase_precession.md --- docs/source/full/day1/phase_precession.md | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/docs/source/full/day1/phase_precession.md b/docs/source/full/day1/phase_precession.md index 3246611..da267e5 100644 --- a/docs/source/full/day1/phase_precession.md +++ b/docs/source/full/day1/phase_precession.md @@ -482,13 +482,6 @@ axs[1].set_ylabel("LFP (a.u.)") fig.supxlabel("Time (s)") ``` -The estimated position oscillates with cycles of theta, which are referred to as "theta sequences". Fully understanding this organization of place cells and its role in learning, memory, and planning is an active topic of research in Neuroscience! +The estimated position oscillates with cycles of theta, where each "sweep" is referred to as a "theta sequence". Fully understanding this organization of place cells and its role in learning, memory, and planning is an active topic of research in Neuroscience! -:::{card} -Authors -^^^ -Kipp Freud (https://kippfreud.com/) -Guillaume Viejo - -:::