diff --git a/setup.py b/setup.py index 58c4e97..7d48ab2 100644 --- a/setup.py +++ b/setup.py @@ -7,8 +7,8 @@ def readme(): return f.read() -if version_info < (3, 3): - raise RuntimeError("Required: Python version > 3.3") +if version_info < (3, 8): + raise RuntimeError("Required: Python version > 3.8") with open("swan/version.py") as fp: d = {} @@ -21,7 +21,7 @@ def readme(): 'PyQt5', 'pyqtgraph', 'odml', - 'elephant', + 'elephant>=0.9.0', 'pyopengl', 'matplotlib', 'scikit-learn', @@ -29,6 +29,7 @@ def readme(): 'scipy', 'pandas', 'colorcet', + 'neo>=0.9.0', ] setup( diff --git a/swan/automatic_mapping.py b/swan/automatic_mapping.py index 1651c8f..a28f36d 100644 --- a/swan/automatic_mapping.py +++ b/swan/automatic_mapping.py @@ -127,10 +127,9 @@ def generate_feature_vectors(parent_class, feature_dictionary, additional_dictio def get_session_ids(blocks): session_ids = [] for b, block in enumerate(blocks): - units = block.channel_indexes[0].units - for u, unit in enumerate(units): - if 'noise' not in unit.description.split() and 'unclassified' not in unit.description.split(): - session_ids.append(b) + units = block.groups + for unit in units: + session_ids.append(b) return session_ids @@ -138,10 +137,9 @@ def get_session_ids(blocks): def get_real_unit_ids(blocks): unit_ids = [] for block in blocks: - units = block.channel_indexes[0].units + units = block.groups for u, unit in enumerate(units): - if 'noise' not in unit.description.split() and 'unclassified' not in unit.description.split(): - unit_ids.append(u) + unit_ids.append(u) return unit_ids @@ -149,12 +147,11 @@ def get_real_unit_ids(blocks): def get_mean_waveforms(blocks): mean_waveforms = [] for block in blocks: - units = block.channel_indexes[0].units + units = block.groups for unit in units: - if 'noise' not in unit.description.split() and 'unclassified' not in unit.description.split(): - waves = unit.spiketrains[0].waveforms.magnitude[:, 0, :] - waves = waves - waves.mean(axis=1, keepdims=True) - mean_waveforms.append(waves.mean(axis=0)) + waves = unit.spiketrains[0].waveforms.magnitude[:, 0, :] + waves = waves - waves.mean(axis=1, keepdims=True) + mean_waveforms.append(waves.mean(axis=0)) return mean_waveforms @@ -162,33 +159,30 @@ def get_mean_waveforms(blocks): def get_all_waveforms(blocks): waveforms = [] for block in blocks: - units = block.channel_indexes[0].units + units = block.groups for unit in units: - if 'noise' not in unit.description.split() and 'unclassified' not in unit.description.split(): - waves = unit.spiketrains[0].waveforms.magnitude[:, 0, :] - waves = waves - waves.mean(axis=1, keepdims=True) - waveforms.append(waves) + waves = unit.spiketrains[0].waveforms.magnitude[:, 0, :] + waves = waves - waves.mean(axis=1, keepdims=True) + waveforms.append(waves) return waveforms def get_all_spiketrains(blocks): spiketrains = [] for block in blocks: - units = block.channel_indexes[0].units + units = block.groups for unit in units: - if 'noise' not in unit.description.split() and 'unclassified' not in unit.description.split(): - train = unit.spiketrains[0].times.magnitude - spiketrains.append(train) + train = unit.spiketrains[0].times.magnitude + spiketrains.append(train) return spiketrains def get_time_stamps(blocks): all_time_stamps = [] for block in blocks: - units = block.channel_indexes[0].units + units = block.groups for unit in units: - if 'noise' not in unit.description.split() and 'unclassified' not in unit.description.split(): - all_time_stamps.append(block.rec_datetime) + all_time_stamps.append(block.rec_datetime) corrected_time_stamps = [] for ts in all_time_stamps: @@ -477,14 +471,28 @@ def __init__(self, parent=None, algorithm=None): self.algorithm = algorithm self.interface = ParameterInputDialogUI(self) + self.plot_widget = self.interface.elbow_curve_plot.pg_canvas + + self.interface.clusters.textChanged.connect(self.update_plot) - self._solvers = None - self._clusters = None + self._solvers = [] + self._clusters = [] + self._inertias = [] self.chosen_solver = None self.values_dict = {} self.final_state = False + self.spot_size = 10 + self.default_pen = pg.mkPen(color='w', + width=2, + style=QtCore.Qt.DotLine) + self.default_symbol_pen = pg.mkPen(color='b') + self.default_brush = pg.mkBrush((0, 255, 0, 128)) + self.default_symbol = 'o' + + self.current_clusters_value = 2 + def on_confirm(self): self.final_state = True self.accept() @@ -493,23 +501,69 @@ def on_cancel(self): self.final_state = False self.reject() - def update_plots(self): - size = 10 - pen = pg.mkPen(color='b') - brush = pg.mkBrush((0, 255, 0, 128)) - symbol = 'o' - - solvers = [] - clusters = [] - inertias = [] + # @QtCore.pyqtSlot(object, object) + # def on_plot_clicked(self, data_item, ev): + # if ev.button() == 1: + # x_pos = ev.pos().x() + # closest_val = self._get_closest_val(x_pos) + # if closest_val is not None: + # self.interface.clusters.setValue(closest_val) + # + # @staticmethod + # def _get_closest_val(input_value): + # lower = np.floor(input_value) + # lower_diff = input_value - lower + # upper = np.ceil(input_value) + # upper_diff = upper - input_value + # + # pprint({"lower": lower, + # "lower_diff": lower_diff, + # "upper": upper, + # "upper_diff": upper_diff}) + # + # if lower_diff < upper_diff and lower_diff < 0.1: + # return lower + # elif upper_diff < lower_diff and upper_diff < 0.1: + # return upper + # else: + # return None + + @QtCore.pyqtSlot() + def update_plot(self): + clusters = self._clusters + inertias = self._inertias sizes = [] pens = [] brushes = [] symbols = [] + self.current_clusters_value = int(self.interface.clusters.value()) + + for cluster in clusters: + sizes.append(self.spot_size) + symbols.append(self.default_symbol) + if cluster == self.current_clusters_value: + pens.append(pg.mkPen('r')) + brushes.append(pg.mkBrush((255, 0, 0, 128))) + else: + pens.append(self.default_symbol_pen) + brushes.append(self.default_brush) + + self._plot(clusters, inertias, + symbolPen=pens, + symbolSize=sizes, + symbol=symbols, + symbolBrush=brushes, + pxMode=True, + clear=True) + + def plot(self): + solvers = [] + clusters = [] + inertias = [] + QtWidgets.QApplication.setOverrideCursor(QtCore.Qt.WaitCursor) self.update_values_dict() - current_clusters_values = self.interface.clusters.value() feature_vectors = generate_feature_vectors(self.algorithm, self.values_dict, self.algorithm.additional_dict) with Pool(processes=None) as pool: @@ -522,32 +576,20 @@ def update_plots(self): solvers.append(solver) clusters.append(n_clusters) inertias.append(solver.inertia_) - sizes.append(size) - symbols.append(symbol) - - if n_clusters == int(current_clusters_values): - pens.append(pg.mkPen('r')) - brushes.append(pg.mkBrush((255, 0, 0, 128))) - else: - pens.append(pen) - brushes.append(brush) self._solvers = solvers self._clusters = clusters + self._inertias = inertias - self.interface.elbow_curve_plot.pg_canvas.plotItem.plot(clusters, inertias, - pen=pg.mkPen(color='w', - width=2, - style=QtCore.Qt.DotLine), - symbolPen=pens, - symbolSize=sizes, - symbol=symbols, - symbolBrush=brushes, - pxMode=True, - clear=True) + self.update_plot() QtWidgets.QApplication.restoreOverrideCursor() + def _plot(self, x, y, *args, **kwargs): + self.interface.elbow_curve_plot.pg_canvas.clear() + plot_item = self.interface.elbow_curve_plot.pg_canvas.plot(x, y, *args, **kwargs) + # plot_item.sigClicked.connect(self.on_plot_clicked) + def update_values_dict(self): output_dict = {} for key in self.interface.key_names: @@ -563,5 +605,5 @@ def exec_(self): QtWidgets.QDialog.exec_(self) self.update_values_dict() if self._solvers is not None: - self.chosen_solver = self._solvers[self._clusters.index(self.interface.clusters.value())] + self.chosen_solver = self._solvers[self._clusters.index(self.current_clusters_value)] return self.values_dict, self.chosen_solver, self.final_state diff --git a/swan/gui/file_dialog_ui.py b/swan/gui/file_dialog_ui.py index abbb182..a59c378 100644 --- a/swan/gui/file_dialog_ui.py +++ b/swan/gui/file_dialog_ui.py @@ -7,94 +7,64 @@ # # WARNING! All changes made in this file will be lost! -from PyQt5 import QtCore, QtGui +from PyQt5 import QtWidgets -try: - _fromUtf8 = QtCore.QString.fromUtf8 -except AttributeError: - _fromUtf8 = lambda s: s +class FileDialogUI(object): -class Ui_File_Dialog(object): - def setupUi(self, file_dialog): - file_dialog.setObjectName(_fromUtf8("file_dialog")) + def __init__(self, file_dialog): file_dialog.resize(557, 605) - self.verticalLayout = QtGui.QVBoxLayout(file_dialog) - self.verticalLayout.setObjectName(_fromUtf8("verticalLayout")) - self.groupBox = QtGui.QGroupBox(file_dialog) - self.groupBox.setObjectName(_fromUtf8("groupBox")) - self.horizontalLayout = QtGui.QHBoxLayout(self.groupBox) - self.horizontalLayout.setObjectName(_fromUtf8("horizontalLayout")) - self.label = QtGui.QLabel(self.groupBox) - self.label.setObjectName(_fromUtf8("label")) + self.verticalLayout = QtWidgets.QVBoxLayout(file_dialog) + + self.groupBox = QtWidgets.QGroupBox("Select the parent folder of your data", file_dialog) + + self.horizontalLayout = QtWidgets.QHBoxLayout(self.groupBox) + + self.label = QtWidgets.QLabel("Path:") self.horizontalLayout.addWidget(self.label) - self.pathEdit = QtGui.QLineEdit(self.groupBox) + + self.pathEdit = QtWidgets.QLineEdit(self.groupBox) self.pathEdit.setReadOnly(True) - self.pathEdit.setObjectName(_fromUtf8("pathEdit")) self.horizontalLayout.addWidget(self.pathEdit) - self.pathBtn = QtGui.QPushButton(self.groupBox) - self.pathBtn.setObjectName(_fromUtf8("pathBtn")) + + self.pathBtn = QtWidgets.QPushButton("Browse...", self.groupBox) self.horizontalLayout.addWidget(self.pathBtn) + self.verticalLayout.addWidget(self.groupBox) - self.groupBox_2 = QtGui.QGroupBox(file_dialog) - self.groupBox_2.setObjectName(_fromUtf8("groupBox_2")) + + self.groupBox_2 = QtWidgets.QGroupBox("Choose files and add them to (or remove them from) " + "the list on the right", + file_dialog) self.groupBox_2.setStyleSheet('QGroupBox:title {' 'text-align: center;' 'subcontrol-origin: content;' 'subcontrol-position: top center; }') - self.horizontalLayout_2 = QtGui.QHBoxLayout(self.groupBox_2) - self.horizontalLayout_2.setObjectName(_fromUtf8("horizontalLayout_2")) - self.selectList = QtGui.QListWidget(self.groupBox_2) - self.selectList.setObjectName(_fromUtf8("selectList")) + + self.horizontalLayout_2 = QtWidgets.QHBoxLayout(self.groupBox_2) + + self.selectList = QtWidgets.QListWidget(self.groupBox_2) self.horizontalLayout_2.addWidget(self.selectList) - self.verticalLayout_2 = QtGui.QVBoxLayout() - self.verticalLayout_2.setObjectName(_fromUtf8("verticalLayout_2")) - self.addBtn = QtGui.QPushButton(self.groupBox_2) - self.addBtn.setObjectName(_fromUtf8("addBtn")) + + self.verticalLayout_2 = QtWidgets.QVBoxLayout() + + self.addBtn = QtWidgets.QPushButton("Add", self.groupBox_2) self.verticalLayout_2.addWidget(self.addBtn) - self.removeBtn = QtGui.QPushButton(self.groupBox_2) - self.removeBtn.setObjectName(_fromUtf8("removeBtn")) - self.verticalLayout_2.addWidget(self.removeBtn) - self.horizontalLayout_2.addLayout(self.verticalLayout_2) - self.selectionList = QtGui.QListWidget(self.groupBox_2) - self.selectionList.setObjectName(_fromUtf8("selectionList")) - self.horizontalLayout_2.addWidget(self.selectionList) - self.verticalLayout.addWidget(self.groupBox_2) - self.btnBox = QtGui.QDialogButtonBox(file_dialog) - self.btnBox.setStandardButtons(QtGui.QDialogButtonBox.Cancel | QtGui.QDialogButtonBox.Ok) - self.btnBox.setObjectName(_fromUtf8("btnBox")) - self.verticalLayout.addWidget(self.btnBox) - self.label.setBuddy(self.pathEdit) - self.retranslateUi(file_dialog) - QtCore.QMetaObject.connectSlotsByName(file_dialog) + self.removeBtn = QtWidgets.QPushButton("Remove", self.groupBox_2) + self.verticalLayout_2.addWidget(self.removeBtn) - def retranslateUi(self, file_dialog): - file_dialog.setWindowTitle(QtGui.QApplication.translate("File_Dialog", - "File selection", - None)) + self.horizontalLayout_2.addLayout(self.verticalLayout_2) - self.groupBox.setTitle(QtGui.QApplication.translate("File_Dialog", - "Select the parent folder of your data", - None)) + self.selectionList = QtWidgets.QListWidget(self.groupBox_2) + self.horizontalLayout_2.addWidget(self.selectionList) - self.label.setText(QtGui.QApplication.translate("File_Dialog", - "Path:", - None)) + self.verticalLayout.addWidget(self.groupBox_2) - self.pathBtn.setText(QtGui.QApplication.translate("File_Dialog", - "Browse...", - None)) + self.btnBox = QtWidgets.QDialogButtonBox(file_dialog) + self.btnBox.setStandardButtons(QtWidgets.QDialogButtonBox.Cancel | QtWidgets.QDialogButtonBox.Ok) - self.groupBox_2.setTitle(QtGui.QApplication.translate("File_Dialog", - "Choose files and add them to (or remove them from) the" - " list on the right", - None)) + self.verticalLayout.addWidget(self.btnBox) - self.addBtn.setText(QtGui.QApplication.translate("File_Dialog", - "Add", - None)) + self.label.setBuddy(self.pathEdit) - self.removeBtn.setText(QtGui.QApplication.translate("File_Dialog", - "Remove", - None)) + file_dialog.setWindowTitle("File selection") diff --git a/swan/gui/parameter_input_dialog_ui.py b/swan/gui/parameter_input_dialog_ui.py index 3bc3311..6404c15 100644 --- a/swan/gui/parameter_input_dialog_ui.py +++ b/swan/gui/parameter_input_dialog_ui.py @@ -123,9 +123,9 @@ def __init__(self, parent_dialog): self.main_layout.addWidget(self.cancel_button, offset_height + 3, 0, 1, 1) self.cancel_button.clicked.connect(parent_dialog.on_cancel) - self.update_plot_button = QtWidgets.QPushButton("Update Plots") + self.update_plot_button = QtWidgets.QPushButton("Update Plot") self.main_layout.addWidget(self.update_plot_button, offset_height + 3, 2, 1, 1) - self.update_plot_button.clicked.connect(parent_dialog.update_plots) + self.update_plot_button.clicked.connect(parent_dialog.plot) self.confirm_button = QtWidgets.QPushButton("Calculate") self.main_layout.addWidget(self.confirm_button, offset_height + 3, 3, 1, 1) diff --git a/swan/main.py b/swan/main.py index aa8a12c..522e9d6 100644 --- a/swan/main.py +++ b/swan/main.py @@ -14,7 +14,6 @@ """ # system imports from os.path import basename, split, join, exists -import csv import webbrowser as web from pyqtgraph.Qt import QtCore, QtGui, QtWidgets import os @@ -25,7 +24,7 @@ # swan-specific imports from swan import about, title from swan.gui.main_ui import MainUI -from swan.widgets.file_dialog import File_Dialog +from swan.widgets.file_dialog import FileDialog from swan.widgets.preferences_dialog import Preferences_Dialog from swan.storage import MyStorage from swan.views.virtual_units_view import VirtualUnitsView @@ -225,7 +224,7 @@ def on_action_new_project_triggered(self): """ This method is called if you click on *File->New Project*. - Shows a :class:`src.file_dialog.File_Dialog` to choose files and after accepting it creates a + Shows a :class:`src.file_dialog.FileDialog` to choose files and after accepting it creates a new project. The project consists of two files. One is a .txt file which contains the data file paths and the other one is a .vum file which contains the :class:`src.virtualunitmap.VirtualUnitMap`. @@ -241,7 +240,7 @@ def on_action_new_project_triggered(self): """ if self.dirty_project(): - dia = File_Dialog() + dia = FileDialog() if dia.exec_(): files = dia.get_files() @@ -349,7 +348,7 @@ def on_action_load_connector_map_triggered(self): directory=self._prodir, options=dialog_options) try: - self.load_connector_map(filename) + self.selector.load_connector_map(filename, self._my_storage.get_channel()) except ValueError: QtWidgets.QMessageBox.critical(None, "Loading error", "The connector map could not be loaded!") @@ -896,47 +895,6 @@ def check_cache(self): # mkdir(self._preferences["cacheDir"]) pathlib.Path(self._preferences["cacheDir"]).mkdir(parents=True, exist_ok=True) - def load_connector_map(self, filename): - """ - Loads a connector map given as a .csv file. - - The file has to contain two columns. The first will be ignored but must exist - (e.g. the numbers 1-100) and the other one has to contain the mapped channel numbers. - Choose **,** as delimiter. - - **Arguments** - - *filename* (string): - The csv file to load. - - **Raises**: :class:`ValueError` - If the connector map could not be loaded. - - """ - if filename: - delimiter = ',' - try: - with open(filename, "r") as fn: - channel_list = [] - reader = csv.reader(fn, delimiter=delimiter) - for row in reader: - # just read the second column - channel_list.append(int(row[1])) - channels = self.selector.get_dirty_channels() - - # overwrite existing mapping - self.selector.set_channels(channel_list) - self.selector.reset_sel() - self.selector.reset_dirty() - - # the dirty channels and the selected one has to be set again - for channel in channels: - self.selector.set_dirty(channel, True) - - self.selector.select_only(self._my_storage.get_channel()) - except Exception as e: - print(e) - def load_preferences(self): """ Loads the preferences from the preferences file. diff --git a/swan/neodata.py b/swan/neodata.py index 2d5967e..f2d68ed 100644 --- a/swan/neodata.py +++ b/swan/neodata.py @@ -12,8 +12,9 @@ """ import gc from itertools import chain -from neo.io.blackrockio_v4 import BlackrockIO from neo.io.pickleio import PickleIO +from neo.io import get_io +from neo import Group, SpikeTrain import numpy as np from numpy.linalg import norm from os.path import join, split, exists @@ -22,6 +23,17 @@ from scipy.signal import filtfilt, butter +def unit_in_channel(unit, channel): + channel_id = unit.annotations.get("channel_id", None) + if channel_id is not None: + if channel_id == channel: + return True + else: + return False + else: + return False + + class NeoData(QObject): """ This class makes it possible to load and manage neo data. @@ -57,30 +69,17 @@ def __init__(self, cache_dir): """ super(QObject, self).__init__() - # properties{ self.cdir = cache_dir self.blocks = [] self.total_units_per_block = [] self.rgios = [] - self._wave_length = 0. + self._wave_length = 0 self.segments = [] self.units = [] self.events = [] self.unique_labels = [] self.sampling_rate = 0. - # } - - #### general methods #### - - # def load_rgIOs(self, files): - # l = len(files) - # step = int(50/l) - # if not self.rgios: - # for i, f in enumerate(files): - # rgIO = BlackrockIO(f) - # self.rgios.append(rgIO) - # self.progress.emit(step*(i+1)) - # return self.rgios + self.current_channel = 0 def load(self, files, channel): """ @@ -99,62 +98,84 @@ def load(self, files, channel): """ # information for the progress - l = len(files) - # if not self.rgios: - # v = 50 - # step = int(50/l) - # else: - # v = 0 - # step = int(100/l) - v = 0 - step = int(100 / l) + num_of_files = len(files) + + count = 0 + step = int(100 / num_of_files) self.delete_blocks() blocks = [] # loading the blocks for i, f in enumerate(files): - name = join(self.cdir, split(f)[1] + "_" + str(channel) + ".pkl") - - if exists(name): - # load from cache - pIO = PickleIO(name) - block = pIO.read_block() - else: - # loading - session = BlackrockIO(f) - block = session.read_block(index=None, name=None, description=None, nsx_to_load='none', - n_starts=None, n_stops=None, channels=channel, units='all', - load_waveforms=True, load_events=True, lazy=False, cascade=True) - del session - - # caching - pIO = PickleIO(name) - pIO.write_block(block) - + # name = join(self.cdir, split(f)[1] + "_" + str(channel) + ".pkl") + # + # if exists(name): + # # load from cache + # pIO = PickleIO(name) + # block = pIO.read_block() + # else: + + session = get_io(f) + block = session.read_block() blocks.append(block) + + del session + # emits a signal with the current progress # after loading a block - self.progress.emit(v + step * (i + 1)) - - self.blocks = blocks - self.segments = [block.segments for block in self.blocks] - nums = [len([unit for unit in b.channel_indexes[0].units - if "noise" not in unit.description.split() - and "unclassified" not in unit.description.split()]) - for b in self.blocks] - self.units = [[unit for unit in b.channel_indexes[0].units - if "noise" not in unit.description.split() - and "unclassified" not in unit.description.split()] - for b in self.blocks] - # self.spiketrains = self.create_spiketrains_dictionary(self.units) + self.progress.emit(count + step * (i + 1)) + + blocks = sorted(blocks, key=lambda x: x.rec_datetime) + + self.blocks = [] + self.segments = [] + self.units = [] + nums = [] + for b, block in enumerate(blocks): + self.units.append([]) + for s, segment in enumerate(block.segments): + + # count spiketrains and save units as neo.Group objects + num_spiketrains = 0 + for spiketrain in segment.spiketrains: + if spiketrain.annotations["channel_id"] == channel and len(spiketrain) > 2: + unit = Group( + objects=[spiketrain], + name=spiketrain.name, + description=f"Unit channel_id: {spiketrain.annotations['channel_id']}, " + f"unit_id: {spiketrain.annotations['unit_id']}", + file_origin=spiketrain.file_origin, + allowed_types=[SpikeTrain], + **spiketrain.annotations, + ) + self.units[b].append(unit) + num_spiketrains += 1 + + nums.append(num_spiketrains) + + block.groups = sorted(self.units[b], key=lambda x: int(x.annotations["unit_id"])) + self.segments.append(block.segments) + self.blocks.append(block) + self.set_events_and_labels() self.total_units_per_block = nums - self._wave_length = len(self.blocks[0].channel_indexes[0].units[0].spiketrains[0].waveforms[0].magnitude[0]) - # TODO: Loop over all sessions to find the first session which has a unit with waveforms + waveform_sizes = [] + for block in blocks: + for unit in block.groups: + waveform_sizes.append(unit.spiketrains[0].waveforms.shape[-1]) - self.sampling_rate = self.blocks[0].channel_indexes[0].units[0].spiketrains[0].sampling_rate + assert np.unique(waveform_sizes).size == 1, "spike_widths across blocks must be equal" + + self._wave_length = np.unique(waveform_sizes)[0] + + try: + self.sampling_rate = self.blocks[0].annotations["sampling_rate"] + except (KeyError, IndexError): + self.sampling_rate = 30000. * pq.Hz + + self.current_channel = channel def get_data(self, layer, unit, **kwargs): """ @@ -240,9 +261,10 @@ def get_yscale(self, layer="average"): yranges0 = [] yranges1 = [] for block in self.blocks: - for unit in block.channel_indexes[0].units: - if "noise" not in unit.description.split() and "unclassified" not in unit.description.split(): + for unit in block.groups: + if unit_in_channel(unit, self.current_channel): datas.append(self.get_data(layer, unit)) + for data in datas: tmp0 = np.min(data) tmp1 = np.max(data) diff --git a/swan/storage.py b/swan/storage.py index dc7fe51..d5eda20 100644 --- a/swan/storage.py +++ b/swan/storage.py @@ -118,8 +118,8 @@ def __init__(self, program_dir, cache_dir): self._loading = False # } - self.store("channel", 1) - self.store("lastchannel", 1) + self.store("channel", 0) + self.store("lastchannel", 0) def set_cache_dir(self, cache_dir): """ @@ -483,10 +483,10 @@ def recalculate(self, mapping=0, parent=None): """ vum = self.get_map() - print("VUM before recalculation: {}".format(np.shape(vum.mapping))) data = self.get_data() + vum.calculate_mapping(data, self, automatic_mapping=mapping, parent=parent) - print("VUM after recalculation: {}".format(np.shape(vum.mapping))) + self.change_map() def revert(self): diff --git a/swan/version.py b/swan/version.py index 7356355..be64af3 100644 --- a/swan/version.py +++ b/swan/version.py @@ -1,2 +1,2 @@ # -*- coding: utf-8 -*- -version = '0.1' +version = '0.1.0' diff --git a/swan/views/isi_histograms_view.py b/swan/views/isi_histograms_view.py index f730715..dc609b8 100644 --- a/swan/views/isi_histograms_view.py +++ b/swan/views/isi_histograms_view.py @@ -239,7 +239,7 @@ def do_plot(self, vum, data): runit = vum.get_realunit(session, unit_id, data) d = data.get_data("sessions", runit) intervals[unit_id].extend(d) - col = vum.get_colour(unit_id, False, layer, False) + col = vum.get_colour(unit_id) self.datas[unit_id] = [np.sort(intervals[unit_id]), col, unit_id, session, clickable] if intervals: @@ -263,7 +263,7 @@ def do_plot(self, vum, data): if active[session][unit_id]: runit = vum.get_realunit(session, unit_id, data) datas = data.get_data("units", runit) - col = vum.get_colour(unit_id, False, layer, False) + col = vum.get_colour(unit_id) clickable = True self.datas["{}{}".format(session, unit_id)] = [datas, col, unit_id, session, clickable] for d in datas: diff --git a/swan/views/mean_waveforms_view.py b/swan/views/mean_waveforms_view.py index a5f1a9e..032ae09 100644 --- a/swan/views/mean_waveforms_view.py +++ b/swan/views/mean_waveforms_view.py @@ -114,7 +114,7 @@ def do_plot(self, vum, data): if layer == "standard deviation": runit = vum.get_realunit(session, unit_id, data) datas = data.get_data(layer, runit) - col = vum.get_colour(unit_id, False, layer, False) + col = vum.get_colour(unit_id) xs = arange(data.get_wave_length()) * 1 / data.sampling_rate.magnitude ys = datas.rescale(V) self.plot_std(xs=xs, ys=ys, color=col) @@ -122,7 +122,7 @@ def do_plot(self, vum, data): elif layer == "average": runit = vum.get_realunit(session, unit_id, data) datas = data.get_data(layer, runit) - col = vum.get_colour(unit_id, False, layer, False) + col = vum.get_colour(unit_id) x = arange(data.get_wave_length()) * 1 / data.sampling_rate.magnitude y = datas.rescale(V) self.plot_mean(x=x, y=y, color=col, unit_id=unit_id, session=session, clickable=True) diff --git a/swan/views/pca_2d_view.py b/swan/views/pca_2d_view.py index dea375c..20bd601 100644 --- a/swan/views/pca_2d_view.py +++ b/swan/views/pca_2d_view.py @@ -119,7 +119,7 @@ def do_plot(self, vum, data): c = 0 for u in range(len(active[i])): if active[i][u]: - col = vum.get_colour(u, False, layer, False) + col = vum.get_colour(u) self.plotPoints(pos=pca_channel[c], size=1, color=col, name="".format(i, u)) self.plotMean(x=mn(pca_channel[c][:, 0], axis=0), y=mn(pca_channel[c][:, 1], axis=0), size=15, color=col, @@ -137,7 +137,7 @@ def do_plot(self, vum, data): c = 0 for u in range(len(active[dom])): if active[dom][u]: - col = vum.get_colour(u, False, layer, False) + col = vum.get_colour(u) self.plotPoints(pos=dom_ch_pca[c], size=1, color=col, name="".format(i, u)) self.plotMean(x=mn(dom_ch_pca[c][:, 0], axis=0), y=mn(dom_ch_pca[c][:, 1], axis=0), size=15, color=col, diff --git a/swan/views/pca_3d_view.py b/swan/views/pca_3d_view.py index 04f9c3e..2b1462e 100644 --- a/swan/views/pca_3d_view.py +++ b/swan/views/pca_3d_view.py @@ -28,6 +28,8 @@ def __init__(self, parent=None): self.positions = [] self.means = [] + self.fill_alpha = 0.9 + self.pg_canvas.set_clickable(True) self.max_distance = 0 @@ -43,8 +45,8 @@ def clear_plot(self): def connect_means(self): self.pg_canvas.set_means(self.means) - for plot in self.pg_canvas.means: - plot.sig_clicked.connect(self.get_item) + # for plot in self.pg_canvas.means: + # plot.sig_clicked.connect(self.get_item) def do_plot(self, vum, data): self.save_camera_position() @@ -70,8 +72,8 @@ def do_plot(self, vum, data): dom_session = [] for unit_index in range(len(active[dom])): - runit = vum.get_realunit(dom, unit_index, data) if active[dom][unit_index]: + runit = vum.get_realunit(dom, unit_index, data) dom_session.append(data.get_data("all", runit)) m_dom_session, lv_dom_session = self.merge_session(dom_session) @@ -86,8 +88,8 @@ def do_plot(self, vum, data): if session_index != dom: session = [] for unit_index in range(len(active[session_index])): - runit = vum.get_realunit(session_index, unit_index, data) if active[session_index][unit_index]: + runit = vum.get_realunit(session_index, unit_index, data) session.append(data.get_data("all", runit)) merged_session, len_vec = self.merge_session(session) @@ -101,7 +103,8 @@ def do_plot(self, vum, data): c = 0 for unit_index in range(len(active[session_index])): if active[session_index][unit_index]: - col = vum.get_colour(unit_index, False, None, True) + col = vum.get_colour(unit_index) + col = tuple(val / 255. for val in col) + (self.fill_alpha,) self.positions.append( self.create_scatter_plot_item(pos=pca_session[c], size=1, color=col, unit_id=unit_index, session=session_index, @@ -129,7 +132,8 @@ def do_plot(self, vum, data): c = 0 for unit_index in range(len(active[dom])): if active[dom][unit_index]: - col = vum.get_colour(unit_index, False, None, True) + col = vum.get_colour(unit_index) + col = tuple(val / 255. for val in col) + (self.fill_alpha,) self.positions.append( self.create_scatter_plot_item(pos=dom_ch_pca[c], size=1, color=col, unit_id=unit_index, session=session_index, diff --git a/swan/views/rate_profile_view.py b/swan/views/rate_profile_view.py index 9f0139f..c24aed0 100644 --- a/swan/views/rate_profile_view.py +++ b/swan/views/rate_profile_view.py @@ -385,7 +385,7 @@ def do_plot(self, vum, data): if active[session][global_unit_id]: unit = vum.get_realunit(session, global_unit_id, data) spiketrain = data.get_data("spiketrain", unit) - col = vum.get_colour(global_unit_id, False, layer, False) + col = vum.get_colour(global_unit_id) self.datas[(session, global_unit_id)] = [spiketrain, col, clickable] if self.trigger_event in self.events.keys(): diff --git a/swan/views/virtual_units_view.py b/swan/views/virtual_units_view.py index fe7a0d2..8920fdb 100644 --- a/swan/views/virtual_units_view.py +++ b/swan/views/virtual_units_view.py @@ -37,6 +37,7 @@ def __init__(self, *args, **kwargs): layout = QtWidgets.QVBoxLayout() self.pg_canvas = pg.PlotWidget() + self.pg_canvas.getViewBox().invertY(True) self.details = QtWidgets.QWidget() details_layout = QtWidgets.QHBoxLayout() @@ -166,14 +167,14 @@ def _add_mesh_item(self, mapping_array, channels, channel_stops): for channel, channel_stop in zip(channels, channel_stops): line = pg.InfiniteLine( - pos=channel_stop, + pos=channel_stop + 0.5, angle=0, pen=pg.fn.mkPen(color='w'), label=f"Channel {channel}", labelOpts={ "movable": True, "position": 0.9, - "anchors": [(0.5, 1), (0.5, 1)] + "anchors": [(0.5, 0), (0.5, 0)] } ) self.pg_canvas.addItem(line) @@ -313,8 +314,8 @@ def _prepare_data(self, args): # User only specified z elif len(args) == 1: # If x and y is None, the polygons will be displaced on a grid - x = np.arange(0, args[0].shape[0] + 1, 1) - y = np.arange(0, args[0].shape[1] + 1, 1) + x = np.arange(0, args[0].shape[0] + 1, 1) + 0.5 # +0.5 to align the polygons with the ticklabels in 1-order + y = np.arange(0, args[0].shape[1] + 1, 1) + 0.5 # +0.5 to align the polygons with the ticklabels in 1-order self.x, self.y = np.meshgrid(x, y, indexing='ij') self.z = args[0] diff --git a/swan/views/waveforms_3d_view.py b/swan/views/waveforms_3d_view.py index 59d1366..3d133f0 100644 --- a/swan/views/waveforms_3d_view.py +++ b/swan/views/waveforms_3d_view.py @@ -114,7 +114,7 @@ def do_plot(self, vum, data): if layer == "average": for unit_index in range(len(active)): if any(active[unit_index]): - col = vum.get_colour(unit_index, False, layer, True) + col = vum.get_colour(unit_index) + (self.fill_alpha,) zs = [] for session_index in range(len(active[unit_index])): if active[unit_index][session_index]: @@ -132,7 +132,7 @@ def do_plot(self, vum, data): elif layer == "standard deviation": for unit_index in range(len(active)): if any(active[unit_index]): - col = vum.get_colour(unit_index, False, layer, True) + col = vum.get_colour(unit_index) + (self.fill_alpha,) zs = [] length = 0 for session_index in range(len(active[unit_index])): diff --git a/swan/virtual_unit_map.py b/swan/virtual_unit_map.py index aea4feb..9a206a8 100644 --- a/swan/virtual_unit_map.py +++ b/swan/virtual_unit_map.py @@ -9,7 +9,6 @@ """ import numpy as np from scipy.spatial.distance import cdist - from swan.automatic_mapping import SwanImplementation from swan.gui.palettes import UNIT_COLORS @@ -58,23 +57,13 @@ def set_initial_map(self, data): """ maximum_units = sum(data.total_units_per_block) - self.total_units = maximum_units - mapping = [] - for session in range(len(data.blocks)): - mapping.append([]) - count = 1 - for global_unit_id in range(maximum_units): - try: - unit_description = data.blocks[session].channel_indexes[0].units[global_unit_id].description.split() - - if "unclassified" in unit_description or "noise" in unit_description: - mapping[session].append(0) - else: - mapping[session].append(count) - count += 1 - except IndexError: - mapping[session].append(0) + mapping = np.zeros((len(data.total_units_per_block), maximum_units), dtype=int).tolist() + for s, session in enumerate(data.blocks): + for pos in range(data.total_units_per_block[s]): + mapping[s][pos] = pos+1 + + self.total_units = maximum_units self.mapping = mapping self.visible = [[True for unit in session] for session in mapping] self.update_active() @@ -85,7 +74,7 @@ def set_map_from_dataframe(self, dataframe): for session_id in range(vmap.shape[0]): session_frame = dataframe.loc[dataframe.session == session_id] for global_unit_id, real_unit_id in zip(session_frame.label, session_frame.unit): - vmap[session_id][global_unit_id] = real_unit_id + vmap[session_id][global_unit_id] = real_unit_id + 1 self.mapping = vmap.astype(np.int32).tolist() self.visible = [[True for unit in session] for session in vmap] @@ -138,11 +127,7 @@ def get_realunit(self, session_index, unit_index, data): """ virtual_unit = self.mapping[session_index][unit_index] - if "unclassified" not in data.blocks[session_index].channel_indexes[0].units[0].description.split(): - real_unit = data.blocks[session_index].channel_indexes[0].units[virtual_unit - 1] - else: - real_unit = data.blocks[session_index].channel_indexes[0].units[virtual_unit] - # real_unit = data.blocks[session_index].channel_indexes[0].units[virtual_unit] + real_unit = data.blocks[session_index].groups[virtual_unit - 1] return real_unit def swap(self, session_index, first_unit_index, second_unit_index): @@ -232,7 +217,7 @@ def get_active(self): def get_color_list(self): return self.colors - def get_colour(self, global_unit_id, mpl=False, layer=None, pqt=False): + def get_colour(self, global_unit_id): """ Returns the color for the given unit row. @@ -254,25 +239,14 @@ def get_colour(self, global_unit_id, mpl=False, layer=None, pqt=False): """ global_unit_id = global_unit_id % self.number_of_colors - if mpl: - col = (self.colors[global_unit_id][0] / 255., - self.colors[global_unit_id][1] / 255., - self.colors[global_unit_id][2] / 255.) - if layer == "standard deviation": - col = (col[0] / 2., col[1] / 2., col[2] / 2.) - elif layer == "session": - col = (col[0] / 2., col[1] / 2., col[2] / 2.) - elif pqt: - col = [self.colors[global_unit_id][0] / 255., - self.colors[global_unit_id][1] / 255., - self.colors[global_unit_id][2] / 255., 0.9] - else: - col = self.colors[global_unit_id] + col = ( + self.colors[global_unit_id][0], + self.colors[global_unit_id][1], + self.colors[global_unit_id][2], + ) return col def swan_implementation(self, data, storage): - - print("Map being calculated") swaps = 0 # Retrieve mapping from base backup_mapping = np.array(storage.get_map().mapping.copy()).T @@ -458,13 +432,11 @@ def calculate_mapping(self, data, storage, automatic_mapping=0, parent=None): """ if automatic_mapping == 0: - print("New Implementation") # self.swan_implementation(data=data, base=base) algorithm = SwanImplementation(neodata=data, parent=parent) self.set_map_from_dataframe(algorithm.result) elif automatic_mapping == 1: - print("Old Implementation") self.calculate_mapping_bu(data=data, storage=storage) def calculate_mapping_bu(self, data, storage): @@ -480,26 +452,22 @@ def calculate_mapping_bu(self, data, storage): which will be used to compare the units. """ - print("Map being calculated") - wave_length = data.get_wave_length() for i in range(len(data.blocks) - 1): sessions = np.zeros((sum(data.total_units_per_block), 2, wave_length)) for j, val in enumerate(storage.get_map().mapping[i]): - if val is not 0: + if val != 0: runit = self.get_realunit(i, j, data) - # sessions[j][0] = data.get_data("average", runit) avg = data.get_data("average", runit) sessions[j][0] = avg / np.max(avg) else: sessions[j][0] = np.zeros(wave_length) for j, val in enumerate(storage.get_map().mapping[i + 1]): - if val is not 0: + if val != 0: runit = self.get_realunit(i + 1, j, data) - # sessions[j][1] = data.get_data("average", runit) avg = data.get_data("average", runit) sessions[j][1] = avg / np.max(avg) else: @@ -514,7 +482,7 @@ def calculate_mapping_bu(self, data, storage): print("Executing this in session {}".format(i)) print("J: {}, Val: {}".format(j, val)) - if val is not 0: + if val != 0: print(distances[j]) min_arg = np.argmin(distances[j]) diff --git a/swan/widgets/file_dialog.py b/swan/widgets/file_dialog.py index c36fda3..9782f0f 100644 --- a/swan/widgets/file_dialog.py +++ b/swan/widgets/file_dialog.py @@ -1,25 +1,15 @@ """ -Created on Oct 24, 2013 - -@author: Christoph Gollan - -In this module you can find the :class:`File_Dialog` which lets +In this module you can find the :class:`FileDialog` which lets you choose files from one directory. """ import os from os import curdir -from pyqtgraph.Qt import QtGui, QtWidgets - -try: - from pyqtgraph.Qt.QtCore import QString -except ImportError: - # we are using Python3 so QString is not defined - QString = str +from pyqtgraph.Qt import QtWidgets -from swan.gui.file_dialog_ui import Ui_File_Dialog +from swan.gui.file_dialog_ui import FileDialogUI -class File_Dialog(QtWidgets.QDialog): +class FileDialog(QtWidgets.QDialog): """ A file dialog which can be used to choose a directory and after that you can choose specific files in this directory. @@ -37,7 +27,7 @@ class File_Dialog(QtWidgets.QDialog): """ - def __init__(self, fileext=".nev", *args, **kwargs): + def __init__(self, fileext=".pkl", *args, **kwargs): """ **Properties** @@ -55,19 +45,15 @@ def __init__(self, fileext=".nev", *args, **kwargs): from found files. """ - QtGui.QDialog.__init__(self, *args, **kwargs) - self.ui = Ui_File_Dialog() - self.ui.setupUi(self) - - #properties{ + super(FileDialog, self).__init__(*args, **kwargs) + self.ui = FileDialogUI(self) + self._path = None self._files = [] self._fileext = fileext - self._extlength = len(self._fileext) - #} - self.ui.selectList.setSelectionMode(QtGui.QAbstractItemView.ExtendedSelection) - self.ui.selectionList.setSelectionMode(QtGui.QAbstractItemView.ExtendedSelection) + self.ui.selectList.setSelectionMode(QtWidgets.QAbstractItemView.ExtendedSelection) + self.ui.selectionList.setSelectionMode(QtWidgets.QAbstractItemView.ExtendedSelection) self.ui.btnBox.accepted.connect(self.accept) self.ui.btnBox.rejected.connect(self.reject) @@ -75,8 +61,6 @@ def __init__(self, fileext=".nev", *args, **kwargs): self.ui.removeBtn.clicked.connect(self.remove) self.ui.pathBtn.clicked.connect(self.browse) self.ui.pathEdit.textChanged.connect(self.pathChangeEvent) - - #### button handler #### def accept(self): """ @@ -85,9 +69,9 @@ def accept(self): Sets the files and closes the dialog. """ - files = self._get_files(self.ui.selectionList) + files = self._get_files() self._files = [os.path.join(self._path, str(f)) for f in files] - QtGui.QDialog.accept(self) + QtWidgets.QDialog.accept(self) def reject(self): """ @@ -96,7 +80,7 @@ def reject(self): Closes the dialog without setting the files. """ - QtGui.QDialog.reject(self) + QtWidgets.QDialog.reject(self) def add(self): """ @@ -125,10 +109,9 @@ def browse(self): Asks you for an existing directory. """ - self.ui.pathEdit.setText(QtGui.QFileDialog.getExistingDirectory(self, str("Choose a directory"), str(curdir))) - - - #### event handler #### + self.ui.pathEdit.setText( + QtWidgets.QFileDialog.getExistingDirectory(self, str("Choose a directory"), str(curdir)) + ) def pathChangeEvent(self, newpath): """ @@ -141,22 +124,17 @@ def pathChangeEvent(self, newpath): The directory that was selected. """ - #path = str(self.ui.pathEdit.text()) path = str(newpath) self._path = path self.ui.selectList.clear() files = [] - #for p, dirs, files in os.walk(path): - # for f in files: if path: for f in os.listdir(path): if f.endswith(self._fileext): files.append(f) self.fillSelectList(files) - - #### general methods #### def fillSelectList(self, files): """ @@ -172,7 +150,7 @@ def fillSelectList(self, files): files.sort() file_list = [] for f in files: - file_list.append(f[:-self._extlength]) + file_list.append(f) self.ui.selectList.addItems(file_list) def updateSelectionList(self, selection, remove=False): @@ -189,7 +167,7 @@ def updateSelectionList(self, selection, remove=False): Default: False. """ - old_files = self._get_files(self.ui.selectionList) + old_files = self._get_files() files = [] if not remove: for item in selection: @@ -203,23 +181,18 @@ def updateSelectionList(self, selection, remove=False): i = old_files.index(text) old_files.remove(text) self.ui.selectionList.takeItem(i) - - def _get_files(self, listWidget): + + def _get_files(self): """ Getter for the files. - - **Arguments** - - *listWidget* (:class:`PyQt5.QtGui.QListWidgetItem`): - The list widget you want the file names from. **Returns**: list of string The file names. """ files = [] - for i in range(listWidget.count()): - files.append(listWidget.item(i).text()) + for i in range(self.ui.selectionList.count()): + files.append(self.ui.selectionList.item(i).text()) return files def get_files(self): @@ -231,6 +204,3 @@ def get_files(self): """ return self._files - - - diff --git a/swan/widgets/gl_view_widget.py b/swan/widgets/gl_view_widget.py index 7a27fbf..bce642b 100644 --- a/swan/widgets/gl_view_widget.py +++ b/swan/widgets/gl_view_widget.py @@ -23,7 +23,7 @@ def __init__(self, app=None): gl.GLViewWidget.__init__(self, parent=app) self.clickable = False - self._mouse_click_pos = [] + self._mouse_click_pos = QtCore.QPoint() self.means = [] self.candidates = [] diff --git a/swan/widgets/mypgwidget.py b/swan/widgets/mypgwidget.py index 341ae07..9986670 100644 --- a/swan/widgets/mypgwidget.py +++ b/swan/widgets/mypgwidget.py @@ -150,16 +150,16 @@ def get_item(self, item): def highlight_curve_from_plot(self, plot, select): scatter_plot_mean = next((x for x in self.means if (x.opts["session"], x.opts["unit_id"]) == plot.pos), None) scatter_plot_pos = next((x for x in self.positions if (x.opts["session"], x.opts["unit_id"]) == plot.pos), None) - suggested_plots = [x for x in self.positions if - x.opts["session"] == plot.pos[0] and x not in self.selected_plots] + # suggested_plots = [x for x in self.positions if + # x.opts["session"] == plot.pos[0] and x not in self.selected_plots] if scatter_plot_mean is not None: if select: - self.clear_suggests() - for suggested_plot in suggested_plots: - if suggested_plot not in self.suggested_plots: - suggested_plot.set_suggest_pen() - self.suggested_plots.append(suggested_plot) + # self.clear_suggests() + # for suggested_plot in suggested_plots: + # if suggested_plot not in self.suggested_plots: + # suggested_plot.set_suggest_pen() + # self.suggested_plots.append(suggested_plot) if scatter_plot_mean not in self.selected_plots: self.selected_plots.append(scatter_plot_mean) if scatter_plot_pos not in self.selected_plots: @@ -177,15 +177,14 @@ def highlight_curve_from_plot(self, plot, select): if scatter_plot_pos in self.selected_plots: scatter_plot_pos.restore_pen() self.selected_plots.remove(scatter_plot_pos) - if not self.selected_plots: - self.clear_suggests() + # if not self.selected_plots: + # self.clear_suggests() def minimumSizeHint(self) -> QtCore.QSize: return QtCore.QSize(500, 400) class MyScatterPlotItem(gl.GLScatterPlotItem, QtWidgets.QGraphicsItem): - sig_clicked = QtCore.pyqtSignal(object) def __init__(self, **kwargs): gl.GLScatterPlotItem.__init__(self, **kwargs) diff --git a/swan/widgets/plot_grid.py b/swan/widgets/plot_grid.py index f023b5b..ad29b1a 100644 --- a/swan/widgets/plot_grid.py +++ b/swan/widgets/plot_grid.py @@ -253,19 +253,19 @@ def do_plot(self, vum, data): plot_widget = self.find_plot(global_unit_id, session) if plot_widget.to_be_updated: plot_widget.clear_() - pen_colour = vum.get_colour(global_unit_id, False, "average", False) + pen_colour = vum.get_colour(global_unit_id) plot_widget.default_pen_colour = pen_colour if active[session][global_unit_id]: unit = vum.get_realunit(session, global_unit_id, data) mean_waveform = data.get_data("average", unit) - all_waveforms = data.get_data("all", unit) - try: - plot_widget.plot_many(all_waveforms[choice(all_waveforms.shape[0], - size=self.sample_waveform_number, - replace=False)], - self._plot_gray) - except ValueError: - plot_widget.plot_many(all_waveforms, self._plot_gray) + # all_waveforms = data.get_data("all", unit) + # try: + # plot_widget.plot_many(all_waveforms[choice(all_waveforms.shape[0], + # size=self.sample_waveform_number, + # replace=False)], + # self._plot_gray) + # except ValueError: + # plot_widget.plot_many(all_waveforms, self._plot_gray) plot_widget.plot(mean_waveform.magnitude, pen_colour) plot_widget.hasPlot = True plot_widget.toggle_colour_strip(pen_colour) @@ -472,9 +472,7 @@ def set_tooltips(self, tooltips): a list of string containing the tool tips for that column. """ - print(len(tooltips)) for col in self._cols.keys(): - print("Setting tooltips for col {}".format(col)) tips = tooltips[col] plots = self._cols[col] for t, plot in zip(tips, plots): diff --git a/swan/widgets/plot_widget.py b/swan/widgets/plot_widget.py index 5a10353..70d9e86 100644 --- a/swan/widgets/plot_widget.py +++ b/swan/widgets/plot_widget.py @@ -296,18 +296,18 @@ def leaveEvent(self, event): event.ignore() -class MultiLine(pg.QtGui.QGraphicsPathItem): +class MultiLine(QtWidgets.QGraphicsPathItem): def __init__(self, data, color): """x and y are 2D arrays of shape (Nplots, Nsamples)""" connect = np.ones(data.shape, dtype=bool) connect[:, -1] = 0 # don't draw the segment between each trace x = np.tile([i for i in range(data.shape[1])], data.shape[0]).reshape(data.shape) self.path = arrayToQPath(x.flatten(), data.flatten(), connect.flatten()) - QtGui.QGraphicsPathItem.__init__(self, self.path) + super(MultiLine, self).__init__(self.path) self.setPen(mkPen(color)) def shape(self): # override because QGraphicsPathItem.shape is too expensive. - return QtGui.QGraphicsItem.shape(self) + return QtWidgets.QGraphicsItem.shape(self) def boundingRect(self): return self.path.boundingRect() diff --git a/swan/widgets/selector_widget.py b/swan/widgets/selector_widget.py index 763cd0a..d18efbd 100644 --- a/swan/widgets/selector_widget.py +++ b/swan/widgets/selector_widget.py @@ -9,6 +9,8 @@ The electrodes on this map are represented by the :class:`SelectorItem`. """ +import csv + from PyQt5 import QtGui, QtCore, QtWidgets @@ -51,8 +53,8 @@ def __init__(self, *args, **kwargs): self._dirty_items = [] self.saved_channels = [] self._sel = None - self.lastchannel = 1 - self.currentchannel = 1 + self.lastchannel = 0 + self.currentchannel = 0 self.autoFillBackground() @@ -113,7 +115,7 @@ def set_channels(self, channel_list=None): else: j = 0 for i in range(9, -1, -1): - channels = range(1+j*10, 11+j*10) + channels = range(j*10, 10+j*10) items = [s for s in self._items if s.pos[0] == i] for k in range(10): items[k].text = str(channels[k]) @@ -239,6 +241,47 @@ def find_saved(self, vum_all): selector_item.repaint() self.saved_channels = saved_channels + def load_connector_map(self, filename, select_channel): + """ + Loads a connector map given as a .csv file. + + The file has to contain two columns. The first will be ignored but must exist + (e.g. the numbers 1-100) and the other one has to contain the mapped channel numbers. + Choose **,** as delimiter. + + **Arguments** + + *filename* (string): + The csv file to load. + + **Raises**: :class:`ValueError` + If the connector map could not be loaded. + + """ + if filename: + delimiter = ',' + try: + with open(filename, "r") as fn: + channel_list = [] + reader = csv.reader(fn, delimiter=delimiter) + for row in reader: + # just read the second column + channel_list.append(int(row[1])) + channels = self.get_dirty_channels() + + # overwrite existing mapping + self.set_channels(channel_list) + self.reset_sel() + self.reset_dirty() + + # the dirty channels and the selected one has to be set again + for channel in channels: + self.set_dirty(channel, True) + + self.select_only(select_channel) + except Exception as e: + print(e) + def minimumSizeHint(self) -> QtCore.QSize: return self.sizeHint()