Skip to content

Commit

Permalink
Use solenoid time for trials.reward_times if available
Browse files Browse the repository at this point in the history
- fixes #147
  • Loading branch information
bjhardcastle committed Oct 23, 2024
1 parent 6efde07 commit f70138b
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
3 changes: 2 additions & 1 deletion src/npc_sessions/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1140,7 +1140,8 @@ def set_lazy_eval(
kwargs |= {"sync": self.sync_data}
if self.is_ephys and self.is_sync:
kwargs |= {"ephys_recording_dirs": self.ephys_recording_dirs}

if (reward_times := getattr(self, "_reward_times_with_duration", None)) is not None:
kwargs |= {"reward_times_with_duration": reward_times.timestamps}
# set items in LazyDict for postponed evaluation
if "RFMapping" in stim_filename:
# create two separate trials tables
Expand Down
24 changes: 18 additions & 6 deletions src/npc_sessions/trials/TaskControl/DynamicRouting1.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,20 +540,32 @@ def response_time(self) -> npt.NDArray[np.float64]:
@npc_io.cached_property
def reward_time(self) -> npt.NDArray[np.floating]:
"""delivery time of water reward, for contingent and non-contingent rewards"""
all_reward_times = npc_stim.safe_index(self._flip_times, self._sam.rewardFrames)
all_reward_times = all_reward_times[all_reward_times <= self.stop_time[-1]]
all_reward_trials = (
if (
(solenoid_times := getattr(self, "_reward_times_with_duration", None)) is not None
and len(solenoid_times) >= len(np.where(self.is_rewarded)[0])
):
logger.info(f'Using solenoid opening time on sync for `reward_time`')
all_reward_times = solenoid_times
else:
logger.info(f'Using flip time of each TaskControl frame for `reward_time`')
all_reward_times = npc_stim.safe_index(self._flip_times, self._sam.rewardFrames)
all_reward_times = all_reward_times[
(all_reward_times >= self.start_time[0]) &
(all_reward_times <= self.stop_time[-1])
]
trial_idx_from_rewards = (
np.searchsorted(
self.start_time,
all_reward_times,
side="right",
)
- 1
)
assert len(is_rewarded := np.where(self.is_rewarded)[0]) <= len(trial_idx_from_rewards)
reward_time = np.full(self._len, np.nan)
if np.all(np.where(self.is_rewarded)[0] == all_reward_trials):
# expected single reward per trial
reward_time[all_reward_trials] = all_reward_times
if np.all(is_rewarded == trial_idx_from_rewards):
# expected case: single reward per trial
reward_time[trial_idx_from_rewards] = all_reward_times
else:
# mismatch between reward times and trials that are marked as having rewards
for trial_idx in np.where(self.is_rewarded)[0]:
Expand Down

0 comments on commit f70138b

Please sign in to comment.