diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..7e645ca Binary files /dev/null and b/.DS_Store differ diff --git a/src/viewephys/.DS_Store b/src/viewephys/.DS_Store new file mode 100644 index 0000000..5008ddf Binary files /dev/null and b/src/viewephys/.DS_Store differ diff --git a/src/viewephys/gui.py b/src/viewephys/gui.py index a21dd2f..80390f1 100644 --- a/src/viewephys/gui.py +++ b/src/viewephys/gui.py @@ -111,8 +111,10 @@ def on_horizontalSliderReleased(self): if not self.cbs[k].isChecked(): continue if k == 'destripe': - data = fcn_destripe(x=data, fs=self.sr.fs, channel_labels=True, h=self.sr.geometry, neuropixel_version=self.sr.major_version) - self.viewers[k] = viewephys(data, self.sr.fs, channels=self.sr.geometry, title=k, t0=t0 * T_SCALAR, t_scalar=T_SCALAR, a_scalar=A_SCALAR) + data = fcn_destripe(x=data, fs=self.sr.fs, channel_labels=True, h=self.sr.geometry, + neuropixel_version=self.sr.major_version) + self.viewers[k] = viewephys(data, self.sr.fs, channels=self.sr.geometry, title=k, + t0=t0 * T_SCALAR, t_scalar=T_SCALAR, a_scalar=A_SCALAR) def closeEvent(self, event): for k in self.viewers: @@ -121,6 +123,96 @@ def closeEvent(self, event): ev.close() self.close() +# =------ + +class SpikeInterfaceViewer(QtWidgets.QMainWindow): + def __init__(self, recording, save_path_picks=None, *args, **kwargs): + """ + :param parent: + :param sr: ibllib.io.spikeglx.Reader instance + """ + super(SpikeInterfaceViewer, self).__init__(*args, *kwargs) + self.settings = QtCore.QSettings('int-brain-lab', 'SpikeInterfaceViewer') + uic.loadUi(Path(__file__).parent.joinpath('nav_file.ui'), self) + self.setWindowIcon(QtGui.QIcon(str(Path(__file__).parent.joinpath('viewephys.svg')))) + self.horizontalSlider.setMinimum(0) + self.horizontalSlider.setSingleStep(1) + self.horizontalSlider.setTickInterval(10) + self.horizontalSlider.sliderReleased.connect(self.on_horizontalSliderReleased) + self.horizontalSlider.valueChanged.connect(self.on_horizontalSliderValueChanged) + self.label_smin.setText('0') + self.show() + self.viewers = {'recording': None} + # self.cbs = {'butterworth': self.cb_butterworth_ap, 'destripe': self.cb_destripe_ap} + self.recording = recording + # Set save path for picks + if save_path_picks is None: + save_path_picks = Path(__file__).parent.joinpath('picks.csv') # This will save into the viewephys folder + elif save_path_picks.suffix != '.csv': + raise ValueError('Extension of save file must be .csv') + self.save_path_picks = save_path_picks + self.set_recording() + + def set_recording(self, *args, **kwargs): + # enable and set slider + num_samples = self.recording.get_num_samples() + fs = self.recording.sampling_frequency + total_duration = self.recording.get_total_duration() + num_channel = self.recording.get_num_channels() + self.horizontalSlider.setMaximum(int(np.floor(num_samples / NSAMP_CHUNK))) + tmax = np.floor(num_samples / NSAMP_CHUNK) * NSAMP_CHUNK / fs + + self.label_smax.setText(f"{tmax:0.2f}s") + tlabel = f'{total_duration} seconds long \n' \ + f'{fs} Hz Sampling Frequency \n' \ + f'{num_channel} Channels' + self.label.setText(tlabel) + first = 0 # first sample + self.horizontalSlider.setValue(first) + self.horizontalSlider.setEnabled(True) + self.on_horizontalSliderReleased() + + for k in self.viewers: + # Propagate save path to each view + self.viewers[k].save_path_picks = self.save_path_picks + self.viewers[k].current_sample0 = first + # TODO make sure picks df remain across views + # Load if exists + if self.save_path_picks.exists(): + df = pd.read_csv(self.save_path_picks) + self.viewers[k].ctrl.model.pickspikes.load_df(df) + self.viewers[k].update_pick_scatter() + + + def on_horizontalSliderValueChanged(self): + tcur = self.horizontalSlider.value() * NSAMP_CHUNK / self.recording.sampling_frequency + self.label_sval.setText(f"{tcur:0.2f}s") + + def on_horizontalSliderReleased(self): + first = int(float(self.horizontalSlider.value()) * NSAMP_CHUNK) + last = first + int(NSAMP_CHUNK) + data = self.recording.get_traces(start_frame=first, end_frame=last).T + + t0 = first / self.recording.sampling_frequency * 0 + # TODO if t0 is not zero the sliders bugs and does not lead to display change (empty) + + for k in self.viewers: + self.viewers[k] = viewephys(data, self.recording.sampling_frequency, + channels=None, title=k, + t0=t0 * T_SCALAR, t_scalar=T_SCALAR, a_scalar=A_SCALAR) + self.viewers[k].current_sample0 = first + self.viewers[k].update_pick_scatter() + + + def closeEvent(self, event): + for k in self.viewers: + ev = self.viewers[k] + if ev is not None: + ev.close() + self.close() + +#----- + class PickSpikes(): @@ -134,6 +226,7 @@ def init_df(self, nrow=0): 'trace': np.zeros(nrow, dtype=np.int32) * -1, 'amp': np.zeros(nrow, dtype=np.int32), 'group': np.zeros(nrow, dtype=np.int32), + 'sample0': np.zeros(nrow, dtype=np.int32) }) return init_df @@ -151,19 +244,20 @@ def load_df(self, df): if isinstance(df, pd.DataFrame): # check all keys are in - indxmissing = np.where(~df.columns.isin(default_df.columns))[0] + indxmissing = np.where(~default_df.columns.isin(df.columns))[0] if len(indxmissing) > 0: raise ValueError(f'df does not contain column {default_df.columns[indxmissing]}') self.update_pick(df) else: raise ValueError('df input is not pd.DataFrame') - def new_row_frompick(self, sample=None, trace=None, amp=None, group=None): + def new_row_frompick(self, sample=None, trace=None, amp=None, group=None, sample0=None): new_row = self.init_df(nrow=1) new_row['sample'] = sample new_row['trace'] = trace new_row['amp'] = amp new_row['group'] = group + new_row['sample0'] = sample0 return new_row def add_spike(self, new_row): @@ -184,7 +278,6 @@ def remove_spike(self, indx_remove): df_updated = df_updated.reset_index(drop=True) self.update_pick(df_updated) - def indx_select(self, sample, trace, s_range=0.5 * 30000, tr_range=3): iclose = np.where(np.logical_and( np.abs(self.picks['sample'] - sample) <= (s_range + 1), @@ -192,6 +285,11 @@ def indx_select(self, sample, trace, s_range=0.5 * 30000, tr_range=3): ))[0] return iclose + def save_picks(self, save_path): + self.picks.to_csv(save_path, index=False) + # chose format CSV output + + class EphysViewer(EasyQC): keyPressed = QtCore.pyqtSignal(int) @@ -202,6 +300,8 @@ def __init__(self, *args, **kwargs): self.menufile.setEnabled(True) self.settings = QtCore.QSettings('int-brain-lab', 'EphysViewer') self.header_curves = {} + self.current_sample0 = 0 + self.save_path_picks = None # menus handling # menu pick self.menupick = self.menuBar().addMenu('&Pick') @@ -278,6 +378,18 @@ def on_key_picking_mode(self, key): match key: case QtCore.Qt.Key.Key_Space: self.ctrl.model.pick_group += 1 + case QtCore.Qt.Key.Key_S: + print(f"Saved picks to: {self.save_path_picks}") + self.ctrl.model.pickspikes.save_picks(self.save_path_picks) + + def update_pick_scatter(self): + # updates scatter plot with only picks from T0 + df = self.ctrl.model.pickspikes.picks + df_local_picks = df.loc[df["sample0"] == self.current_sample0] + + self.ctrl.add_scatter(df_local_picks['sample'] * self.ctrl.model.si, + df_local_picks['trace'], + label='_picks', rgb=PICK_COLOR) def mouseClickPickingEvent(self, event): """ @@ -294,6 +406,8 @@ def mouseClickPickingEvent(self, event): return TR_RANGE = 3 S_RANGE = int(0.5 / self.ctrl.model.si) + # TODO modify s_range so it is scaled according to zoom, + # otherwise can be hard to delete when zoomed out qxy = self.imageItem_seismic.mapFromScene(event.scenePos()) s, tr = (qxy.x(), qxy.y()) # if event.buttons() == QtCore.Qt.MiddleButton: @@ -331,13 +445,11 @@ def mouseClickPickingEvent(self, event): group = 0 # TODO group # Create new row new_row = self.ctrl.model.pickspikes.new_row_frompick( - sample=tmax, trace=xmax, amp=amp, group=group) + sample=tmax, trace=xmax, amp=amp, group=group, sample0=self.current_sample0) self.ctrl.model.pickspikes.add_spike(new_row=new_row) # updates scatter plot - self.ctrl.add_scatter(self.ctrl.model.pickspikes.picks['sample'] * self.ctrl.model.si, - self.ctrl.model.pickspikes.picks['trace'], - label='_picks', rgb=PICK_COLOR) + self.update_pick_scatter() def save_current_plot(self, filename): """ diff --git a/src/viewephys/picks.csv b/src/viewephys/picks.csv new file mode 100644 index 0000000..eb92427 --- /dev/null +++ b/src/viewephys/picks.csv @@ -0,0 +1,4 @@ +sample,trace,amp,group,sample0 +2717,297,26411326.0,0,330000 +3612,296,-75356264.0,0,330000 +3991,297,-64742784.0,0,330000 diff --git a/src/viewephys/raster.py b/src/viewephys/raster.py index 399fd63..9b5498d 100644 --- a/src/viewephys/raster.py +++ b/src/viewephys/raster.py @@ -15,7 +15,7 @@ import one.alf.io as alfio from one.alf.files import get_session_path import spikeglx -from ibldsp import voltage, utils +from ibldsp import voltage from iblatlas.atlas import BrainRegions from viewephys.gui import viewephys, SNS_PALETTE @@ -181,9 +181,10 @@ def show_ephys(self, t0, tlen=.4): sos = scipy.signal.butter(**butter_kwargs, output='sos') butt = scipy.signal.sosfiltfilt(sos, raw) destripe = voltage.destripe(raw, fs=self.data.sr.fs, channel_labels=True) - self.eqc_raw = viewephys(butt, self.data.sr.fs, channels=self.data.channels, br=self.data.br, title='butt', t0=t0, t_scalar=1) - self.eqc_des = viewephys(destripe, self.data.sr.fs, channels=self.data.channels, br=self.data.br, title='destripe', t0=t0, t_scalar=1) - stripes_noise = 20 * np.log10(np.median(utils.rms(butt - destripe))) + self.eqc_raw = viewephys(butt, self.data.sr.fs, channels=self.data.channels, + br=self.data.br, title='butt', t0=t0, t_scalar=1) + self.eqc_des = viewephys(destripe, self.data.sr.fs, channels=self.data.channels, + br=self.data.br, title='destripe', t0=t0, t_scalar=1) eqc_xrange = [t0 + tlen / 2 - 0.01, t0 + tlen / 2 + 0.01] self.eqc_des.viewBox_seismic.setXRange(*eqc_xrange) self.eqc_raw.viewBox_seismic.setXRange(*eqc_xrange) diff --git a/src/viewephys/spike_interface.ui b/src/viewephys/spike_interface.ui new file mode 100644 index 0000000..74e13e0 --- /dev/null +++ b/src/viewephys/spike_interface.ui @@ -0,0 +1,153 @@ + + + MainWindow + + + true + + + + 0 + 0 + 785 + 295 + + + + Spike Interface Viewer + + + + + + + false + + + Qt::Horizontal + + + + + + + QFrame::StyledPanel + + + QFrame::Raised + + + + + + Dataset Info + + + Qt::AlignCenter + + + + + + + Qt::Horizontal + + + + 40 + 20 + + + + + + + + + 160 + 0 + + + + + + + + + + + + 16777215 + 40 + + + + QFrame::StyledPanel + + + QFrame::Raised + + + + + + smin + + + + + + + sval + + + Qt::AlignCenter + + + + + + + smax + + + Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter + + + + + + + + + + + + 0 + 0 + 785 + 22 + + + + + File + + + + + + + + open + + + + + open live recording + + + + + + diff --git a/src/viewephys/tests/test_model.py b/src/viewephys/tests/test_model.py index 5d6aa0d..92d2492 100644 --- a/src/viewephys/tests/test_model.py +++ b/src/viewephys/tests/test_model.py @@ -14,4 +14,3 @@ def test_model_dataclass(): ProbeData(spikes=spikes, clusters=clusters, channels=channels) ProbeData(spikes=pd.DataFrame(spikes), clusters=pd.DataFrame(clusters), channels=pd.DataFrame(channels)) - diff --git a/src/viewephys/tests/test_pick.py b/src/viewephys/tests/test_pick.py index 62be664..7c5204d 100644 --- a/src/viewephys/tests/test_pick.py +++ b/src/viewephys/tests/test_pick.py @@ -3,7 +3,7 @@ import pandas as pd ps = PickSpikes() -DEFAULT_DF_COLUMNS = ['sample', 'trace', 'amp', 'group'] +DEFAULT_DF_COLUMNS = ['sample', 'trace', 'amp', 'group', 'sample0'] def test_init_df(): @@ -16,7 +16,7 @@ def test_init_df(): def test_new_row_frompick(): - new_row = ps.new_row_frompick(sample=1, trace=2, amp=3, group=4) + new_row = ps.new_row_frompick(sample=1, trace=2, amp=3, group=4, sample0=0) # Check size np.testing.assert_(new_row.shape[0] == 1) # Check column names @@ -41,8 +41,8 @@ def test_update_pick(): # ---- # Create filled df (2 rows) - df1 = ps.new_row_frompick(sample=1, trace=2, amp=3, group=4) - df2 = ps.new_row_frompick(sample=3, trace=2, amp=3, group=5) + df1 = ps.new_row_frompick(sample=1, trace=2, amp=3, group=4, sample0=0) + df2 = ps.new_row_frompick(sample=3, trace=2, amp=3, group=5, sample0=0) df = pd.concat([df1, df2]) # Update ps.update_pick(df) @@ -53,8 +53,8 @@ def test_update_pick(): def test_add_spike(): - df1 = ps.new_row_frompick(sample=1, trace=2, amp=3, group=4) - df2 = ps.new_row_frompick(sample=3, trace=2, amp=3, group=5) + df1 = ps.new_row_frompick(sample=1, trace=2, amp=3, group=4, sample0=0) + df2 = ps.new_row_frompick(sample=3, trace=2, amp=3, group=5, sample0=0) df = pd.concat([df1, df2]) df = df.reset_index(drop=True) ps.update_pick(df1) @@ -68,10 +68,10 @@ def test_add_spike(): def test_remove_spike(): - df1 = ps.new_row_frompick(sample=1, trace=2, amp=3, group=4) - df2 = ps.new_row_frompick(sample=3, trace=2, amp=3, group=5) - df3 = ps.new_row_frompick(sample=6, trace=6, amp=3, group=5) - df4 = ps.new_row_frompick(sample=7, trace=6, amp=3, group=5) + df1 = ps.new_row_frompick(sample=1, trace=2, amp=3, group=4, sample0=0) + df2 = ps.new_row_frompick(sample=3, trace=2, amp=3, group=5, sample0=0) + df3 = ps.new_row_frompick(sample=6, trace=6, amp=3, group=5, sample0=0) + df4 = ps.new_row_frompick(sample=7, trace=6, amp=3, group=5, sample0=0) df = pd.concat([df1, df2, df3, df4]) df = df.reset_index(drop=True) # Update @@ -84,3 +84,8 @@ def test_remove_spike(): ps.remove_spike(indx_remove=indx_remove) pd.testing.assert_frame_equal(ps.picks, df_test) + + +def test_load_df(): + df1 = ps.new_row_frompick() + ps.load_df(df1) # check no raise