From 662f88ec82c7d37a0dc275a9f8a2f0ce909ce766 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Fri, 18 Oct 2024 18:24:11 +0100 Subject: [PATCH 1/7] start work on adaptiveTimeout task --- .../__init__.py | 0 .../nate_adaptiveTimeoutChoiceWorld/task.py | 41 +++++++++++++++++++ .../task_parameters.yaml | 0 3 files changed, 41 insertions(+) create mode 100644 iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/__init__.py create mode 100644 iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task.py create mode 100644 iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task_parameters.yaml diff --git a/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/__init__.py b/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task.py b/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task.py new file mode 100644 index 0000000..137d085 --- /dev/null +++ b/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task.py @@ -0,0 +1,41 @@ +""" +This task inherits TrainingChoiceWorldSession with the addition of configurable, adaptive timeouts for incorrect +choices depending on the stimulus contrast. +""" + +import logging +from pathlib import Path + +import yaml + +from iblrig.base_choice_world import TrainingChoiceWorldSession +from iblrig.misc import get_task_arguments +from pybpodapi.state_machine import StateMachine + +log = logging.getLogger('iblrig.task') + + +# read defaults from task_parameters.yaml +with open(Path(__file__).parent.joinpath('task_parameters.yaml')) as f: + DEFAULTS = yaml.safe_load(f) + + +class AdaptiveTimeoutStateMachine(StateMachine): + def add_state(self, **kwargs): + super().add_state(**kwargs) + + +class Session(TrainingChoiceWorldSession): + protocol_name = 'nate_adaptiveTimeoutChoiceWorld' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def _instantiate_state_machine(self, trial_number=None): + return AdaptiveTimeoutStateMachine(self.bpod) + + +if __name__ == '__main__': # pragma: no cover + kwargs = get_task_arguments(parents=[Session.extra_parser()]) + sess = Session(**kwargs) + sess.run() diff --git a/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task_parameters.yaml b/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task_parameters.yaml new file mode 100644 index 0000000..e69de29 From 38ec4d28d074fb6b7369cc2ba8d6e8cec60cb004 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Fri, 18 Oct 2024 18:39:19 +0100 Subject: [PATCH 2/7] add initial outline for adaptive timeout task --- .../nate_adaptiveTimeoutChoiceWorld/task.py | 33 ++++++++++++++++++- .../task_parameters.yaml | 3 ++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task.py b/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task.py index 137d085..0fb67e7 100644 --- a/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task.py +++ b/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task.py @@ -28,12 +28,43 @@ def add_state(self, **kwargs): class Session(TrainingChoiceWorldSession): protocol_name = 'nate_adaptiveTimeoutChoiceWorld' - def __init__(self, *args, **kwargs): + def __init__( + self, + *args, + adaptive_delay_nogo=DEFAULTS['ADAPTIVE_FEEDBACK_NOGO_DELAY_SECS'], + adaptive_delay_error=DEFAULTS['ADAPTIVE_FEEDBACK_ERROR_DELAY_SECS'], + **kwargs, + ): + self.adaptive_delay_nogo = adaptive_delay_nogo + self.adaptive_delay_error = adaptive_delay_error super().__init__(*args, **kwargs) def _instantiate_state_machine(self, trial_number=None): return AdaptiveTimeoutStateMachine(self.bpod) + @staticmethod + def extra_parser(): + parser = super(Session, Session).extra_parser() + parser.add_argument( + '--adaptive_delay_nogo', + option_strings=['--adaptive_delay_nogo'], + dest='adaptive_delay_nogo', + default=DEFAULTS['ADAPTIVE_FEEDBACK_NOGO_DELAY_SECS'], + nargs='+', + type=float, + help='list of delays for no-go condition (contrasts: 1.0, 0.25, 0.125, 0.0625, 0.0)', + ) + parser.add_argument( + '--adaptive_delay_error', + option_strings=['--adaptive_delay_error'], + dest='adaptive_delay_nogo', + default=DEFAULTS['ADAPTIVE_FEEDBACK_ERROR_DELAY_SECS'], + nargs='+', + type=float, + help='list of delays for error condition (contrasts: 1.0, 0.25, 0.125, 0.0625, 0.0)', + ) + return parser + if __name__ == '__main__': # pragma: no cover kwargs = get_task_arguments(parents=[Session.extra_parser()]) diff --git a/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task_parameters.yaml b/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task_parameters.yaml index e69de29..183ddeb 100644 --- a/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task_parameters.yaml +++ b/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task_parameters.yaml @@ -0,0 +1,3 @@ +# define delays for the following contrasts: [1.0, 0.25, 0.125, 0.0625, 0.0] +'ADAPTIVE_FEEDBACK_NOGO_DELAY_SECS': [5.0, 4.0, 3.0, 2.0, 2.0] +'ADAPTIVE_FEEDBACK_ERROR_DELAY_SECS': [5.0, 4.0, 3.0, 2.0, 2.0] \ No newline at end of file From 0911d7e68cadc484bc397afe0001a34ae7ed2a1b Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Fri, 18 Oct 2024 18:57:52 +0100 Subject: [PATCH 3/7] add some details --- .../nate_adaptiveTimeoutChoiceWorld/task.py | 21 ++++++++++++++++++- .../task_parameters.yaml | 13 +++++++++--- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task.py b/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task.py index 0fb67e7..b431949 100644 --- a/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task.py +++ b/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task.py @@ -21,7 +21,24 @@ class AdaptiveTimeoutStateMachine(StateMachine): + + def __init__( + self, + bpod, + adaptive_delay_nogo, + adaptive_delay_error + ): + super().__init__(bpod) + self.adaptive_delay_nogo = adaptive_delay_nogo + self.adaptive_delay_error = adaptive_delay_error + + def add_state(self, **kwargs): + match kwargs['state_name']: + case 'nogo': + pass + case 'error': + pass super().add_state(**kwargs) @@ -38,9 +55,11 @@ def __init__( self.adaptive_delay_nogo = adaptive_delay_nogo self.adaptive_delay_error = adaptive_delay_error super().__init__(*args, **kwargs) + assert len(self.adaptive_delay_nogo) == len(self.task_params.CONTRAST_SET) + assert len(self.adaptive_delay_error) == len(self.task_params.CONTRAST_SET) def _instantiate_state_machine(self, trial_number=None): - return AdaptiveTimeoutStateMachine(self.bpod) + return AdaptiveTimeoutStateMachine(self.bpod, self.adaptive_delay_nogo, self.adaptive_delay_error) @staticmethod def extra_parser(): diff --git a/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task_parameters.yaml b/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task_parameters.yaml index 183ddeb..4c5aada 100644 --- a/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task_parameters.yaml +++ b/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task_parameters.yaml @@ -1,3 +1,10 @@ -# define delays for the following contrasts: [1.0, 0.25, 0.125, 0.0625, 0.0] -'ADAPTIVE_FEEDBACK_NOGO_DELAY_SECS': [5.0, 4.0, 3.0, 2.0, 2.0] -'ADAPTIVE_FEEDBACK_ERROR_DELAY_SECS': [5.0, 4.0, 3.0, 2.0, 2.0] \ No newline at end of file +'ADAPTIVE_GAIN': True +'ADAPTIVE_REWARD': True +'AG_INIT_VALUE': 8.0 # Adaptive Gain init value. Once the mouse completes 200 response trials whithin a session, this reverts to STIM_GAIN +'CONTRAST_SET_PROBABILITY_TYPE': skew_zero # uniform, skew_zero +'DEBIAS': True # Whether to use debiasing rule or not by repeating error trials +'REWARD_AMOUNT_UL': 3.0 # Reward amount (uL), will oscillate between 1.5 and 3 uL depending on previous sessions if adaptive_reward is True + +'CONTRAST_SET': [1.0, 0.5, 0.25, 0.125, 0.0625, 0.0] +'ADAPTIVE_FEEDBACK_NOGO_DELAY_SECS': [6.0, 5.0, 4.0, 3.0, 2.0, 2.0] +'ADAPTIVE_FEEDBACK_ERROR_DELAY_SECS': [6.0, 5.0, 4.0, 3.0, 2.0, 2.0] From 2336277e59068e6d2f2b0d1b1df440e15923fa83 Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Mon, 21 Oct 2024 12:26:05 +0100 Subject: [PATCH 4/7] Update task.py --- .../nate_adaptiveTimeoutChoiceWorld/task.py | 72 ++++++++++--------- 1 file changed, 40 insertions(+), 32 deletions(-) diff --git a/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task.py b/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task.py index b431949..555ced7 100644 --- a/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task.py +++ b/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task.py @@ -5,12 +5,14 @@ import logging from pathlib import Path +from typing import Any +import numpy as np import yaml +from pydantic import NonNegativeFloat -from iblrig.base_choice_world import TrainingChoiceWorldSession from iblrig.misc import get_task_arguments -from pybpodapi.state_machine import StateMachine +from iblrig_tasks._iblrig_tasks_trainingChoiceWorld.task import Session as TrainingCWSession log = logging.getLogger('iblrig.task') @@ -20,30 +22,14 @@ DEFAULTS = yaml.safe_load(f) -class AdaptiveTimeoutStateMachine(StateMachine): - - def __init__( - self, - bpod, - adaptive_delay_nogo, - adaptive_delay_error - ): - super().__init__(bpod) - self.adaptive_delay_nogo = adaptive_delay_nogo - self.adaptive_delay_error = adaptive_delay_error +class AdaptiveTimeoutChoiceWorldTrialData(TrainingCWSession.TrialDataModel): + adaptive_delay_nogo: NonNegativeFloat + adaptive_delay_error: NonNegativeFloat - def add_state(self, **kwargs): - match kwargs['state_name']: - case 'nogo': - pass - case 'error': - pass - super().add_state(**kwargs) - - -class Session(TrainingChoiceWorldSession): +class Session(TrainingCWSession): protocol_name = 'nate_adaptiveTimeoutChoiceWorld' + TrialDataModel = AdaptiveTimeoutChoiceWorldTrialData def __init__( self, @@ -52,14 +38,36 @@ def __init__( adaptive_delay_error=DEFAULTS['ADAPTIVE_FEEDBACK_ERROR_DELAY_SECS'], **kwargs, ): - self.adaptive_delay_nogo = adaptive_delay_nogo - self.adaptive_delay_error = adaptive_delay_error + self._adaptive_delay_nogo = adaptive_delay_nogo + self._adaptive_delay_error = adaptive_delay_error super().__init__(*args, **kwargs) - assert len(self.adaptive_delay_nogo) == len(self.task_params.CONTRAST_SET) - assert len(self.adaptive_delay_error) == len(self.task_params.CONTRAST_SET) - - def _instantiate_state_machine(self, trial_number=None): - return AdaptiveTimeoutStateMachine(self.bpod, self.adaptive_delay_nogo, self.adaptive_delay_error) + assert len(self._adaptive_delay_nogo) == len(self.task_params.CONTRAST_SET) + assert len(self._adaptive_delay_error) == len(self.task_params.CONTRAST_SET) + + def draw_next_trial_info(self, **kwargs): + super().draw_next_trial_info(**kwargs) + contrast = self.trials_table.at[self.trial_num, 'contrast'] + index = np.flatnonzero(np.array(self.task_params['CONTRAST_SET']) == contrast)[0] + self.trials_table.at[self.trial_num, 'adaptive_delay_nogo'] = self._adaptive_delay_nogo[index] + self.trials_table.at[self.trial_num, 'adaptive_delay_error'] = self._adaptive_delay_error[index] + + @property + def feedback_nogo_delay(self): + return self.trials_table.at[self.trial_num, 'adaptive_delay_nogo'] + + @property + def feedback_error_delay(self): + return self.trials_table.at[self.trial_num, 'adaptive_delay_error'] + + def show_trial_log(self, extra_info: dict[str, Any] | None = None, log_level: int = logging.INFO): + trial_info = self.trials_table.iloc[self.trial_num] + info_dict = { + 'Adaptive no-go delay': f'{trial_info.adaptive_delay_nogo:.2f} s', + 'Adaptive error delay': f'{trial_info.adaptive_delay_error:.2f} s', + } + if isinstance(extra_info, dict): + info_dict.update(extra_info) + super().show_trial_log(extra_info=info_dict, log_level=log_level) @staticmethod def extra_parser(): @@ -71,7 +79,7 @@ def extra_parser(): default=DEFAULTS['ADAPTIVE_FEEDBACK_NOGO_DELAY_SECS'], nargs='+', type=float, - help='list of delays for no-go condition (contrasts: 1.0, 0.25, 0.125, 0.0625, 0.0)', + help='list of delays for no-go condition (contrasts: 1.0, 0.5, 0.25, 0.125, 0.0625, 0.0)', ) parser.add_argument( '--adaptive_delay_error', @@ -80,7 +88,7 @@ def extra_parser(): default=DEFAULTS['ADAPTIVE_FEEDBACK_ERROR_DELAY_SECS'], nargs='+', type=float, - help='list of delays for error condition (contrasts: 1.0, 0.25, 0.125, 0.0625, 0.0)', + help='list of delays for error condition (contrasts: 1.0, 0.5, 0.25, 0.125, 0.0625, 0.0)', ) return parser From 6ac6f5d20da7b13d3cdb729b829d677f0ffde4bc Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Mon, 21 Oct 2024 14:58:18 +0100 Subject: [PATCH 5/7] change default values --- .../nate_adaptiveTimeoutChoiceWorld/task_parameters.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task_parameters.yaml b/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task_parameters.yaml index 4c5aada..ea91878 100644 --- a/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task_parameters.yaml +++ b/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task_parameters.yaml @@ -6,5 +6,5 @@ 'REWARD_AMOUNT_UL': 3.0 # Reward amount (uL), will oscillate between 1.5 and 3 uL depending on previous sessions if adaptive_reward is True 'CONTRAST_SET': [1.0, 0.5, 0.25, 0.125, 0.0625, 0.0] -'ADAPTIVE_FEEDBACK_NOGO_DELAY_SECS': [6.0, 5.0, 4.0, 3.0, 2.0, 2.0] -'ADAPTIVE_FEEDBACK_ERROR_DELAY_SECS': [6.0, 5.0, 4.0, 3.0, 2.0, 2.0] +'ADAPTIVE_FEEDBACK_NOGO_DELAY_SECS': [2.0, 2.0, 2.0, 2.0, 2.0, 2.0] +'ADAPTIVE_FEEDBACK_ERROR_DELAY_SECS': [2.0, 2.0, 2.0, 2.0, 2.0, 2.0] From f31f9c241fac39dfbc96dd2d97d7f84c6d58c34a Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Mon, 21 Oct 2024 15:24:40 +0100 Subject: [PATCH 6/7] Update task.py --- iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task.py b/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task.py index 555ced7..32cedf9 100644 --- a/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task.py +++ b/iblrig_custom_tasks/nate_adaptiveTimeoutChoiceWorld/task.py @@ -84,7 +84,7 @@ def extra_parser(): parser.add_argument( '--adaptive_delay_error', option_strings=['--adaptive_delay_error'], - dest='adaptive_delay_nogo', + dest='adaptive_delay_error', default=DEFAULTS['ADAPTIVE_FEEDBACK_ERROR_DELAY_SECS'], nargs='+', type=float, From 6f694080aa1d7a0457afe0462ab9b5b54d3205ab Mon Sep 17 00:00:00 2001 From: Florian Rau Date: Mon, 21 Oct 2024 15:37:16 +0100 Subject: [PATCH 7/7] Update pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 797da0e..7ed8c04 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "project_extraction" -version = "0.3.0" +version = "0.4.0" description = "Custom extractors for satellite tasks" dynamic = [ "readme" ] keywords = [ "IBL", "neuro-science" ]