The main purpose of this library is to create a predictive model for retinal ganglion cell activity: given a stimulus and spike history, predict the next 100 ms or so of spike activity. A secondary objective is to provide an efficient and easy-to-use API for working with data collected from multi-electrode arrays (MEAs).
pip install retinapy
spikeprediction.py and models.py define the neural network models and training objectives used for the prediction task. dataset.py turns the MEA data into PyTorch datasets for consumption by the training loop.
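To make the task concrete, a training sample pairs an input window of stimulus and spike history with a short window of future spikes. The layout below is an illustrative assumption, not the actual dataset's schema:

import torch

# Hypothetical sample layout (names and shapes are assumptions, not retinapy's API).
sample = {
    # Full-field noise stimulus: 4 LED channels over a 1 s input window,
    # at ~992 timesteps per second after downsampling.
    "stimulus": torch.rand(4, 992),
    # The cell's binary spike train over the same window.
    "spike_history": torch.randint(0, 2, (992,)).float(),
    # Target: the next ~100 ms of spike activity (~99 timesteps).
    "target_spikes": torch.zeros(99),
}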
Currently, the model is trained and tested on MEA recordings of chicken retina exposed to a full-field noise stimulus, collected by Marvin Seifert (https://doi.org/10.1101). The model is trained to predict the individual output of >1000 "cells" collected over 18 recordings. "Cells" is in quotes as they are the purported cells identified by the spike-sorting algorithm.
Spike predictions for one cell are shown below, for about 5 seconds of test data:
spike_prediction.mp4
Below, for the same cell, predicted and ground-truth spikes are counted in 100 ms bins. The data covers ~86 seconds of test data, without smoothing over time or averaging over multiple trials.
(A line chart probably isn't so appropriate here, but it makes a visual comparison easier than just using points.)
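This sort of binning is a one-liner with NumPy's histogram. A minimal sketch, where pred_times and true_times are hypothetical arrays of spike times in seconds:

import numpy as np

pred_times = np.array([0.03, 0.12, 0.19])  # predicted spike times (s)
true_times = np.array([0.05, 0.11])        # ground-truth spike times (s)

# Count spikes in 100 ms bins across the ~86 s test window.
bin_edges = np.arange(0, 86.1, 0.1)
pred_counts, _ = np.histogram(pred_times, bins=bin_edges)
true_counts, _ = np.histogram(true_times, bins=bin_edges)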
The model is trained once for all cells. (recording_id, cell_id) tuples are encoded via a variational autoencoder. The aim is that, in the future, an additional network can be used to place unknown cells into the embedding space, so that spike prediction can be done for additional cells from new recordings. A consequence of this approach is that the encoding space can be inspected to see if any interesting clustering has emerged. Below is a screenshot of a t-SNE plot of the latent space. On inspection, the STA kernels for nearby points are similar.
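The general idea can be sketched in a few lines of PyTorch. This is a minimal illustration of a VAE-style embedding, not retinapy's actual implementation; the class name and dimensions are assumptions:

import torch
import torch.nn as nn

class CellEmbedding(nn.Module):
    # Illustrative sketch only: each (recording, cell) pair, flattened to a
    # single index, gets a learned Gaussian posterior in the latent space.
    def __init__(self, num_cells, latent_dim=2):
        super().__init__()
        self.mu = nn.Embedding(num_cells, latent_dim)
        self.logvar = nn.Embedding(num_cells, latent_dim)

    def forward(self, cell_idx):
        mu = self.mu(cell_idx)
        std = torch.exp(0.5 * self.logvar(cell_idx))
        # Reparameterization trick: sample a latent code differentiably.
        return mu + std * torch.randn_like(std)

During training, a KL penalty toward a standard normal keeps the latent space smooth, which is what makes placing new, unseen cells into it plausible.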
Some reusable PyTorch modules live in nn.py.
train.py and _logging.py contain general-purpose training infrastructure in PyTorch, focused around a training loop. It supports a slimmed-down version of the feature set you might get from a library such as fastai or PyTorch Lightning.
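For orientation, the core of any such loop is the same few lines; the infrastructure's job is to wrap them. A generic PyTorch sketch, not the actual train.py API:

import torch

def train_one_epoch(model, loader, optimizer, loss_fn, device="cpu"):
    # The bare loop that training infrastructure typically wraps with
    # logging, checkpointing, LR scheduling and evaluation hooks.
    model.train()
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(inputs), targets)
        loss.backward()
        optimizer.step()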
If you aren't training deep learning models, you might still find some of the functionality in the retinapy.mea module useful. It handles loading MEA data and provides some useful functions such as downsampling and data splitting. Look no further if you just want to extract spike snippets for spike-triggered averaging.
>>> import retinapy.mea as mea
>>> rec = mea.single_3brain_recording(
... rec_name="Chicken_17_08_21_Phase_00",
... data_dir="./data/ff_noise_recordings")
>>> # Extract spike windows.
>>> # I'm downsampling by 18, which gives ~1000 timesteps per second (992).
>>> downsample = 18
>>> rec = mea.decompress_recording(rec, downsample)
>>> snippets = rec.spike_snippets(total_len=800, post_spike_len=200)
>>> print(len(snippets))
154
>>> print(snippets[0].shape)
(17993, 800, 4)
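From here, the spike-triggered average is just a mean over the spike axis. Assuming each entry of snippets stacks one (timesteps, channels) window per spike, a hypothetical continuation of the session:

>>> # Average the stimulus over all spike-aligned windows of the first cell
>>> # (assumes axis 0 indexes spikes).
>>> sta = snippets[0].mean(axis=0)
>>> print(sta.shape)
(800, 4)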
Various plotting functions are collected in vis.py. Plotly is the main plotting library used in this module.
For a bit of fun, there is a visualization tool in /snippet_viewer which inspects cells by viewing all of the snippets that contribute to a cell's STA kernel. Below is a video showing the 500 ms leading up to every spike of a given cell.
A hosted version is at: https://mea.bio/
demo2.mp4
The stimulus is 50-50 on-off 4-channel color noise. Each square shows the time progression of the stimulus shown to the cell leading up to a spike. Our stimulator has 4 LEDs. In the visualization, the intensities of the 3 LEDs most similar to the red, green, and blue sRGB primaries are mapped to the sRGB color values of the inner square, and the intensity of the 4th LED (UV) is visualized as a purple border.
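The mapping itself is straightforward. A rough NumPy sketch of the idea; the function name, frame layout, and channel order are assumptions, not the tool's actual code:

import numpy as np

def frame_to_rgb(leds, size=64, border=8):
    # leds: 4 intensities in [0, 1], ordered (red-ish, green-ish,
    # blue-ish, UV). Illustrative sketch only.
    leds = np.asarray(leds, dtype=float)
    img = np.zeros((size, size, 3))
    img[:, :] = leds[3] * np.array([0.5, 0.0, 0.5])   # UV drawn as a purple border
    img[border:-border, border:-border] = leds[:3]    # inner square: sRGB values
    return img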