Skip to content

Commit

Permalink
Update online_plots.py
Browse files Browse the repository at this point in the history
  • Loading branch information
bimac committed Dec 19, 2024
1 parent 6244318 commit 1f045d6
Showing 1 changed file with 90 additions and 54 deletions.
144 changes: 90 additions & 54 deletions iblrig/gui/online_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
QSizePolicy,
QStyledItemDelegate,
QTableView,
QVBoxLayout,
QWidget,
)

from iblqt.core import DataFrameTableModel
Expand All @@ -43,6 +45,39 @@
from iblrig.raw_data_loaders import bpod_session_data_to_dataframe, load_task_jsonable


class PlotWidget(pg.PlotWidget):
"""PlotWidget with tuned default settings."""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.setBackground('white')
self.plotItem.getViewBox().setBackgroundColor(pg.mkColor(250, 250, 250))
self.plotItem.setMouseEnabled(x=False, y=False)
self.plotItem.setMenuEnabled(False)
self.plotItem.hideButtons()
for axis in ('left', 'bottom'):
self.plotItem.getAxis(axis).setTextPen('k')


class SingleBarChart(PlotWidget):
"""A bar chart with a single column"""

def __init__(self, *args, barBrush='k', **kwargs):
super().__init__(*args, **kwargs)
self.plotItem.getAxis('left').setWidth(40)
self.plotItem.getAxis('left').setGrid(128)
self.plotItem.getAxis('bottom').setLabel(' ')
self.plotItem.getAxis('bottom').setTicks([[(1, ' ')], []])
self.plotItem.getAxis('bottom').setStyle(tickLength=0, tickAlpha=0)
self.plotItem.setXRange(min=0, max=2, padding=0)
self._barGraphItem = pg.BarGraphItem(x=1, width=2, height=0, pen=None, brush=barBrush)
self.addItem(self._barGraphItem)

@Slot(float)
def setValue(self, value: float):
self._barGraphItem.setOpts(height=value)


class TrialsTableModel(DataFrameTableModel):
"""A table model that displays status tips for entries in the trials table."""

Expand Down Expand Up @@ -79,9 +114,9 @@ class TrialsTableView(QTableView):
def __init__(self, parent: QObject):
super().__init__(parent)
self.setMouseTracking(True)
self.setVerticalScrollMode(QAbstractItemView.ScrollPerPixel)
# self.setVerticalScrollMode(QAbstractItemView.ScrollPerPixel)
self.verticalHeader().hide()
# self.horizontalHeader().hide()
self.horizontalHeader().hide()
self.horizontalHeader().setDefaultAlignment(Qt.AlignLeft)
self.horizontalHeader().setSectionResizeMode(QHeaderView.Fixed)
self.horizontalHeader().setStretchLastSection(True)
Expand Down Expand Up @@ -119,6 +154,41 @@ def paintEvent(self, event):
super().paintEvent(event)


class TrialsWidget(QWidget):
trialSelected = Signal(int)

def __init__(self, parent: QObject, model: TrialsTableModel):
super().__init__(parent)
self.model = model

layout = QVBoxLayout(self)
layout.setSpacing(4)
layout.setContentsMargins(0, 8, 0, 36)
self.setLayout(layout)

self.titleLabel = QLabel('Trials History')
self.titleLabel.setAlignment(Qt.AlignHCenter)
font = self.titleLabel.font()
font.setPointSize(11)
self.titleLabel.setFont(font)
layout.addWidget(self.titleLabel)

self.table_view = TrialsTableView(self)
self.table_view.setModel(self.model)
self.table_view.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.ResizeToContents)
self.table_view.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.ResizeToContents)
self.table_view.setColumnHidden(2, True)
self.table_view.setColumnHidden(3, True)
self.table_view.setColumnHidden(4, True)
self.table_view.selectionModel().selectionChanged.connect(self._onSelectionChange)
layout.addWidget(self.table_view)
layout.setStretch(1, 1)

@Slot(QItemSelection, QItemSelection)
def _onSelectionChange(self, selected: QItemSelection, _deselected: QItemSelection):
self.trialSelected.emit(selected.indexes()[0].row())


class OnlinePlotsModel(QObject):
currentTrialChanged = Signal(int)
_trial_data = pd.DataFrame()
Expand Down Expand Up @@ -443,7 +513,7 @@ def __init__(self, raw_data_folder: DirectoryPath, parent: QObject | None = None

self.statusBar().clearMessage()
self.setWindowTitle('Online Plots')
self.setMinimumSize(1024, 768)
self.setMinimumSize(1024, 771)
self.setWindowIcon(QIcon(QPixmap(':/images/iblrig_logo')))

# the frame that contains all the plots
Expand Down Expand Up @@ -475,29 +545,13 @@ def __init__(self, raw_data_folder: DirectoryPath, parent: QObject | None = None
layout.addWidget(subtitle, 1, 0, 1, 3)

# trial data
self.trials = TrialsTableView(self)
self.trials.setModel(self.model.table_model)
self.trials.selectionModel().selectionChanged.connect(self.onSelectionChanged)
self.trials.horizontalHeader().setSectionResizeMode(0, QHeaderView.ResizeMode.ResizeToContents)
self.trials.horizontalHeader().setSectionResizeMode(1, QHeaderView.ResizeMode.ResizeToContents)
self.trials.setColumnHidden(2, True)
self.trials.setColumnHidden(3, True)
self.trials.setColumnHidden(4, True)
self.trials = TrialsWidget(self, self.model.table_model)
self.trials.trialSelected.connect(self.model.setCurrentTrial)
layout.addWidget(self.trials, 2, 0, 2, 1)

# properties common to all pyqtgraph plots
def common_plot_item_props(plot_item: pg.PlotItem):
plot_item.getViewBox().setBackgroundColor(pg.mkColor(250, 250, 250))
plot_item.setMouseEnabled(x=False, y=False)
plot_item.setMenuEnabled(False)
plot_item.hideButtons()
for axis in ('left', 'bottom'):
plot_item.getAxis(axis).setTextPen('k')

# properties common to psychometric/chronometric functions
def common_function_props(plot_widget: pg.PlotWidget) -> dict[Any, pg.PlotDataItem]:
plot_item = plot_widget.plotItem
common_plot_item_props(plot_item)
plot_item.addItem(pg.InfiniteLine(0, 90, 'black'))
for axis in ('left', 'bottom'):
plot_item.getAxis(axis).setGrid(128)
Expand All @@ -521,7 +575,7 @@ def common_function_props(plot_widget: pg.PlotWidget) -> dict[Any, pg.PlotDataIt
return plot_data_items

# psychometric function
self.psychometricFunction = pg.PlotWidget(parent=self, background='white')
self.psychometricFunction = PlotWidget(parent=self)
layout.addWidget(self.psychometricFunction, 2, 1, 1, 1)
self.psychometricFunction.plotItem.setTitle('Psychometric Function', color='k')
self.psychometricFunction.plotItem.getAxis('left').setLabel('Rightward Choices (%)')
Expand All @@ -530,45 +584,30 @@ def common_function_props(plot_widget: pg.PlotWidget) -> dict[Any, pg.PlotDataIt
self.psychometricPlots = common_function_props(self.psychometricFunction)

# chronometric function
self.chronometricFunction = pg.PlotWidget(parent=self, background='white')
self.chronometricFunction = PlotWidget(parent=self)
layout.addWidget(self.chronometricFunction, 3, 1, 1, 1)
self.chronometricFunction.plotItem.setTitle('Chronometric Function', color='k')
self.chronometricFunction.plotItem.getAxis('left').setLabel('Response Time (s)')
self.chronometricFunction.plotItem.setLogMode(x=False, y=True)
self.chronometricFunction.plotItem.setYRange(-1, 2, padding=0.05)
self.chronometricPlots = common_function_props(self.chronometricFunction)

# properties common to all bar charts
def common_bar_chart_props(plot_item: pg.PlotItem):
common_plot_item_props(plot_item)
plot_item.getAxis('left').setWidth(40)
plot_item.getAxis('left').setGrid(128)
plot_item.getAxis('bottom').setLabel(' ')
plot_item.getAxis('bottom').setTicks([[(1, ' ')], []])
plot_item.getAxis('bottom').setStyle(tickLength=0)
plot_item.setXRange(min=0, max=2, padding=0)
plot_item.hoverEvent = self.mouseOverBarChart

# performance chart
self.performanceWidget = pg.PlotWidget(parent=self, background='white')
layout.addWidget(self.performanceWidget, 2, 2, 1, 1)
common_bar_chart_props(self.performanceWidget.plotItem)
self.performanceWidget = SingleBarChart(parent=self)
self.performanceWidget.setMinimumWidth(155)
self.performanceWidget.plotItem.setTitle('Performance', color='k')
self.performanceWidget.plotItem.getAxis('left').setLabel('Correct Choices (%)')
self.performancePlot = pg.BarGraphItem(x=1, width=2, height=0, pen=None, brush='k')
self.performanceWidget.addItem(self.performancePlot)
self.performanceWidget.plotItem.setYRange(0, 105, padding=0)
self.performanceWidget.plotItem.hoverEvent = self.mouseOverBarChart
layout.addWidget(self.performanceWidget, 2, 2, 1, 1)

# reward chart
self.rewardWidget = pg.PlotWidget(parent=self, background='white')
self.rewardWidget.setMinimumWidth(135)
layout.addWidget(self.rewardWidget, 3, 2, 1, 1)
common_bar_chart_props(self.rewardWidget.plotItem)
self.rewardWidget.plotItem.setTitle('Total Reward', color='k')
self.rewardWidget.plotItem.getAxis('left').setLabel('Reward Amount (μl)')
self.rewardPlot = pg.BarGraphItem(x=1, width=2, height=0, pen=None, brush='b')
self.rewardWidget.addItem(self.rewardPlot)
self.rewardWidget = SingleBarChart(parent=self, barBrush='blue')
self.rewardWidget.plotItem.setTitle('Reward Amount', color='k')
self.rewardWidget.plotItem.getAxis('left').setLabel('Total Reward Volume (μl)')
self.rewardWidget.plotItem.setYRange(0, 1050, padding=0)
self.rewardWidget.plotItem.hoverEvent = self.mouseOverBarChart
layout.addWidget(self.rewardWidget, 3, 2, 1, 1)

# bpod data
self.bpodWidget = BpodWidget(self, title='Bpod States and Input Channels')
Expand All @@ -587,26 +626,23 @@ def mouseOverBarChart(self, event):
if event.currentItem == self.performanceWidget.plotItem:
statusbar.showMessage(f'Performance: {self.model.percentCorrect():0.1f}% correct choices')
else:
statusbar.showMessage(f'Total reward amount: {self.model.reward_amount:0.1f} μl')
statusbar.showMessage(f'Total reward volume: {self.model.reward_amount:0.1f} μl')

@Slot(int)
def updatePlots(self, trial: int):
self.title.setText(f'Trial {trial}')
self.bpodWidget.setData(self.model.bpod_data(trial))
self.trials.setCurrentIndex(self.model.table_model.index(trial, 0))
self.trials.table_view.setCurrentIndex(self.model.table_model.index(trial, 0))
if trial == self.model.table_model.columnCount() - 1:
self.trials.scrollToBottom()
for p in self.model.probability_set:
idx = (p, self.model.signed_contrasts)
self.psychometricPlots[p].setData(x=idx[1], y=self.model.psychometrics.loc[idx, 'choice'].to_list())
self.chronometricPlots[p].setData(x=idx[1], y=self.model.psychometrics.loc[idx, 'response_time'].to_list())
self.performancePlot.setOpts(height=self.model.percentCorrect())
self.rewardPlot.setOpts(height=self.model.reward_amount)
self.performanceWidget.setValue(self.model.percentCorrect())
self.rewardWidget.setValue(self.model.reward_amount)
self.update()

def onSelectionChanged(self, selected: QItemSelection, _: QItemSelection):
self.model.setCurrentTrial(selected.indexes()[0].row())

def keyPressEvent(self, event) -> None:
"""Navigate trials using directional keys."""
match event.key():
Expand Down

0 comments on commit 1f045d6

Please sign in to comment.