Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for neo >= 0.9.0, elephant >= 0.9.0, python >= 3.8 #34

Draft
wants to merge 45 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
d444d78
neo 0.9.0 compatibility changes
shashwatsridhar Apr 15, 2021
3d0a013
minor code cleanup
shashwatsridhar Apr 15, 2021
58a0b44
improve plot in ParameterInputDialog
shashwatsridhar Apr 15, 2021
96dfdf5
remove multiline plot due to pyqtgraph compatibility
shashwatsridhar Apr 15, 2021
4e4f4d7
dtype specification for vum initialization
shashwatsridhar Apr 15, 2021
622f3d3
update neo, elephant, python version requirements
shashwatsridhar Apr 15, 2021
0ead407
remove redundant print statement
shashwatsridhar Apr 15, 2021
ee300d0
simplify color retrieval
shashwatsridhar Apr 15, 2021
0dcd700
fix color retrieval for 3d views
shashwatsridhar Apr 16, 2021
3d652e4
fix color retrieval for 3d views
shashwatsridhar Apr 16, 2021
7458f18
fix unit display bug
shashwatsridhar Apr 16, 2021
43d2978
disable suggested plot highlight in 3d pca
shashwatsridhar Apr 16, 2021
96ffefb
move connector map loading function to selector_widget.py
shashwatsridhar Apr 23, 2021
fa308b5
set zero-based channel numbering
shashwatsridhar Apr 23, 2021
de62fe8
align virtual units view polygons with tick labels
shashwatsridhar Apr 23, 2021
2886cd4
remove redundant signal in 3d scatter plot item
shashwatsridhar Apr 23, 2021
982b677
minor change in initialization of internal var
shashwatsridhar Apr 23, 2021
05b3f32
fix bug in code logic
shashwatsridhar Apr 23, 2021
3223ddf
smarter way to determine wavelength across blocks
shashwatsridhar Apr 23, 2021
144559b
simplify data load for pkl files (no caching)
shashwatsridhar Apr 23, 2021
4ec976e
sort input blocks by data
shashwatsridhar Apr 23, 2021
ce0df2b
minor code cleanup
shashwatsridhar Apr 15, 2021
69d3de3
improve plot in ParameterInputDialog
shashwatsridhar Apr 15, 2021
bc4171c
remove multiline plot due to pyqtgraph compatibility
shashwatsridhar Apr 15, 2021
9137896
code cleanup of virtual_unit_map.py
shashwatsridhar Apr 15, 2021
0455bb6
remove redundant print statement
shashwatsridhar Apr 15, 2021
1ec8467
simplify color retrieval
shashwatsridhar Apr 15, 2021
d5325e5
fix color retrieval for 3d views
shashwatsridhar Apr 16, 2021
be386d6
fix color retrieval for 3d views
shashwatsridhar Apr 16, 2021
5adba30
disable suggested plot highlight in 3d pca
shashwatsridhar Apr 16, 2021
614a05c
align virtual units view polygons with tick labels
shashwatsridhar Apr 23, 2021
18993ae
remove redundant signal in 3d scatter plot item
shashwatsridhar Apr 23, 2021
f629b67
minor change in initialization of internal var
shashwatsridhar Apr 23, 2021
96bfc66
fix bug in code logic
shashwatsridhar Apr 23, 2021
6343269
correct version number update
shashwatsridhar May 14, 2021
b418da3
neo 0.9.0 compatibility changes
shashwatsridhar Apr 15, 2021
4f0f6dd
dtype specification for vum initialization
shashwatsridhar Apr 15, 2021
d9f2e30
update neo, elephant, python version requirements
shashwatsridhar Apr 15, 2021
179aaef
fix unit display bug
shashwatsridhar Apr 16, 2021
a8dcac8
move connector map loading function to selector_widget.py
shashwatsridhar Apr 23, 2021
e1abb53
set zero-based channel numbering
shashwatsridhar Apr 23, 2021
ebb7320
smarter way to determine wavelength across blocks
shashwatsridhar Apr 23, 2021
b220707
simplify data load for pkl files (no caching)
shashwatsridhar Apr 23, 2021
7b6a94f
sort input blocks by data
shashwatsridhar Apr 23, 2021
357696b
Merge branch 'enh/neo_elephant_support' of github.com:/INM-6/swan int…
shashwatsridhar Jun 15, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ def readme():
return f.read()


if version_info < (3, 3):
raise RuntimeError("Required: Python version > 3.3")
if version_info < (3, 8):
raise RuntimeError("Required: Python version > 3.8")

with open("swan/version.py") as fp:
d = {}
Expand All @@ -21,14 +21,15 @@ def readme():
'PyQt5',
'pyqtgraph',
'odml',
'elephant',
'elephant>=0.9.0',
'pyopengl',
'matplotlib',
'scikit-learn',
'psutil',
'scipy',
'pandas',
'colorcet',
'neo>=0.9.0',
]

setup(
Expand Down
154 changes: 98 additions & 56 deletions swan/automatic_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,68 +127,62 @@ def generate_feature_vectors(parent_class, feature_dictionary, additional_dictio
def get_session_ids(blocks):
session_ids = []
for b, block in enumerate(blocks):
units = block.channel_indexes[0].units
for u, unit in enumerate(units):
if 'noise' not in unit.description.split() and 'unclassified' not in unit.description.split():
session_ids.append(b)
units = block.groups
for unit in units:
session_ids.append(b)

return session_ids


def get_real_unit_ids(blocks):
unit_ids = []
for block in blocks:
units = block.channel_indexes[0].units
units = block.groups
for u, unit in enumerate(units):
if 'noise' not in unit.description.split() and 'unclassified' not in unit.description.split():
unit_ids.append(u)
unit_ids.append(u)

return unit_ids


def get_mean_waveforms(blocks):
mean_waveforms = []
for block in blocks:
units = block.channel_indexes[0].units
units = block.groups
for unit in units:
if 'noise' not in unit.description.split() and 'unclassified' not in unit.description.split():
waves = unit.spiketrains[0].waveforms.magnitude[:, 0, :]
waves = waves - waves.mean(axis=1, keepdims=True)
mean_waveforms.append(waves.mean(axis=0))
waves = unit.spiketrains[0].waveforms.magnitude[:, 0, :]
waves = waves - waves.mean(axis=1, keepdims=True)
mean_waveforms.append(waves.mean(axis=0))

return mean_waveforms


def get_all_waveforms(blocks):
waveforms = []
for block in blocks:
units = block.channel_indexes[0].units
units = block.groups
for unit in units:
if 'noise' not in unit.description.split() and 'unclassified' not in unit.description.split():
waves = unit.spiketrains[0].waveforms.magnitude[:, 0, :]
waves = waves - waves.mean(axis=1, keepdims=True)
waveforms.append(waves)
waves = unit.spiketrains[0].waveforms.magnitude[:, 0, :]
waves = waves - waves.mean(axis=1, keepdims=True)
waveforms.append(waves)
return waveforms


def get_all_spiketrains(blocks):
spiketrains = []
for block in blocks:
units = block.channel_indexes[0].units
units = block.groups
for unit in units:
if 'noise' not in unit.description.split() and 'unclassified' not in unit.description.split():
train = unit.spiketrains[0].times.magnitude
spiketrains.append(train)
train = unit.spiketrains[0].times.magnitude
spiketrains.append(train)
return spiketrains


def get_time_stamps(blocks):
all_time_stamps = []
for block in blocks:
units = block.channel_indexes[0].units
units = block.groups
for unit in units:
if 'noise' not in unit.description.split() and 'unclassified' not in unit.description.split():
all_time_stamps.append(block.rec_datetime)
all_time_stamps.append(block.rec_datetime)

corrected_time_stamps = []
for ts in all_time_stamps:
Expand Down Expand Up @@ -477,14 +471,28 @@ def __init__(self, parent=None, algorithm=None):
self.algorithm = algorithm

self.interface = ParameterInputDialogUI(self)
self.plot_widget = self.interface.elbow_curve_plot.pg_canvas

self.interface.clusters.textChanged.connect(self.update_plot)

self._solvers = None
self._clusters = None
self._solvers = []
self._clusters = []
self._inertias = []
self.chosen_solver = None

self.values_dict = {}
self.final_state = False

self.spot_size = 10
self.default_pen = pg.mkPen(color='w',
width=2,
style=QtCore.Qt.DotLine)
self.default_symbol_pen = pg.mkPen(color='b')
self.default_brush = pg.mkBrush((0, 255, 0, 128))
self.default_symbol = 'o'

self.current_clusters_value = 2

def on_confirm(self):
self.final_state = True
self.accept()
Expand All @@ -493,23 +501,69 @@ def on_cancel(self):
self.final_state = False
self.reject()

def update_plots(self):
size = 10
pen = pg.mkPen(color='b')
brush = pg.mkBrush((0, 255, 0, 128))
symbol = 'o'

solvers = []
clusters = []
inertias = []
# @QtCore.pyqtSlot(object, object)
# def on_plot_clicked(self, data_item, ev):
# if ev.button() == 1:
# x_pos = ev.pos().x()
# closest_val = self._get_closest_val(x_pos)
# if closest_val is not None:
# self.interface.clusters.setValue(closest_val)
#
# @staticmethod
# def _get_closest_val(input_value):
# lower = np.floor(input_value)
# lower_diff = input_value - lower
# upper = np.ceil(input_value)
# upper_diff = upper - input_value
#
# pprint({"lower": lower,
# "lower_diff": lower_diff,
# "upper": upper,
# "upper_diff": upper_diff})
#
# if lower_diff < upper_diff and lower_diff < 0.1:
# return lower
# elif upper_diff < lower_diff and upper_diff < 0.1:
# return upper
# else:
# return None

@QtCore.pyqtSlot()
def update_plot(self):
clusters = self._clusters
inertias = self._inertias
sizes = []
pens = []
brushes = []
symbols = []

self.current_clusters_value = int(self.interface.clusters.value())

for cluster in clusters:
sizes.append(self.spot_size)
symbols.append(self.default_symbol)
if cluster == self.current_clusters_value:
pens.append(pg.mkPen('r'))
brushes.append(pg.mkBrush((255, 0, 0, 128)))
else:
pens.append(self.default_symbol_pen)
brushes.append(self.default_brush)

self._plot(clusters, inertias,
symbolPen=pens,
symbolSize=sizes,
symbol=symbols,
symbolBrush=brushes,
pxMode=True,
clear=True)

def plot(self):
solvers = []
clusters = []
inertias = []

QtWidgets.QApplication.setOverrideCursor(QtCore.Qt.WaitCursor)
self.update_values_dict()
current_clusters_values = self.interface.clusters.value()
feature_vectors = generate_feature_vectors(self.algorithm, self.values_dict, self.algorithm.additional_dict)

with Pool(processes=None) as pool:
Expand All @@ -522,32 +576,20 @@ def update_plots(self):
solvers.append(solver)
clusters.append(n_clusters)
inertias.append(solver.inertia_)
sizes.append(size)
symbols.append(symbol)

if n_clusters == int(current_clusters_values):
pens.append(pg.mkPen('r'))
brushes.append(pg.mkBrush((255, 0, 0, 128)))
else:
pens.append(pen)
brushes.append(brush)

self._solvers = solvers
self._clusters = clusters
self._inertias = inertias

self.interface.elbow_curve_plot.pg_canvas.plotItem.plot(clusters, inertias,
pen=pg.mkPen(color='w',
width=2,
style=QtCore.Qt.DotLine),
symbolPen=pens,
symbolSize=sizes,
symbol=symbols,
symbolBrush=brushes,
pxMode=True,
clear=True)
self.update_plot()

QtWidgets.QApplication.restoreOverrideCursor()

def _plot(self, x, y, *args, **kwargs):
self.interface.elbow_curve_plot.pg_canvas.clear()
plot_item = self.interface.elbow_curve_plot.pg_canvas.plot(x, y, *args, **kwargs)
# plot_item.sigClicked.connect(self.on_plot_clicked)

def update_values_dict(self):
output_dict = {}
for key in self.interface.key_names:
Expand All @@ -563,5 +605,5 @@ def exec_(self):
QtWidgets.QDialog.exec_(self)
self.update_values_dict()
if self._solvers is not None:
self.chosen_solver = self._solvers[self._clusters.index(self.interface.clusters.value())]
self.chosen_solver = self._solvers[self._clusters.index(self.current_clusters_value)]
return self.values_dict, self.chosen_solver, self.final_state
Loading