Skip to content

Commit

Permalink
Generate fig3 for units with high context sel index
Browse files Browse the repository at this point in the history
  • Loading branch information
bjhardcastle committed Aug 24, 2024
1 parent b221245 commit 3dcc457
Showing 1 changed file with 69 additions and 27 deletions.
96 changes: 69 additions & 27 deletions src/npc_sessions_cache/figures/paper2/fig3c.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def plot(unit_id: str, stim_names=("vis1", "vis2", "sound1", "sound2")) -> plt.F
session_id = npc_session.SessionRecord(unit_id).id

units_all_sessions = utils.get_component_df("units")
licks_all_sessions = utils.get_component_zarr("spike_times")
spike_times_all_sessions = utils.get_component_zarr("spike_times")
trials_all_sessions = utils.get_component_df("trials")
all_sessions = utils.get_component_df("session")
performance_all_sessions = utils.get_component_df("performance")
Expand All @@ -42,36 +42,36 @@ def plot(unit_id: str, stim_names=("vis1", "vis2", "sound1", "sound2")) -> plt.F

#! session id is without idx for spike times
spike_times_session_id = "_".join(unit_id.split("_")[:2])
subject_lick_times: npt.NDArray = licks_all_sessions[spike_times_session_id][unit_id][:]
if not subject_lick_times.size:
raise ValueError(f"No lick times found for {unit_id}")
unit_spike_times: npt.NDArray = spike_times_all_sessions[spike_times_session_id][unit_id][:]
if not unit_spike_times.size:
raise ValueError(f"No spike times found for {unit_id}")
modality_to_rewarded_stim = {"aud": "sound1", "vis": "vis1"}

# add licks to trials:
# add spikes to trials:
pad_start = 1.5 # seconds
lick_times_by_trial = tuple(
subject_lick_times[slice(start, stop)] if 0 <= start < stop <= len(subject_lick_times) else []
spike_times_by_trial = tuple(
unit_spike_times[slice(start, stop)] if 0 <= start < stop <= len(unit_spike_times) else []
for start, stop in np.searchsorted(
subject_lick_times, trials.select(pl.col("start_time") - pad_start, "stop_time")
unit_spike_times, trials.select(pl.col("start_time") - pad_start, "stop_time")
)
)
if not lick_times_by_trial or not any(np.array(a).any() for a in lick_times_by_trial):
raise ValueError(f"No lick times found matching trial times {unit} - either no task presented or major timing issue")
if not spike_times_by_trial or not any(np.array(a).any() for a in spike_times_by_trial):
raise ValueError(f"No spike times found matching trial times {unit} - either no task presented or major timing issue")
trials = (
trials
.with_columns(
pl.Series(name="lick_times", values=lick_times_by_trial, dtype=pl.List(pl.Float64)), # doesn't handle empty entries well without explicit dtype
pl.Series(name="spike_times", values=spike_times_by_trial, dtype=pl.List(pl.Float64)), # doesn't handle empty entries well without explicit dtype
)
.with_row_index()
.explode("lick_times")
.explode("spike_times")
.with_columns(
stim_centered_lick_times=(
pl.col("lick_times")
- pl.col("stim_start_time").alias("stim_centered_lick_times")
stim_centered_spike_times=(
pl.col("spike_times")
- pl.col("stim_start_time").alias("stim_centered_spike_times")
)
)
.group_by(
pl.all().exclude("lick_times", "stim_centered_lick_times"),
pl.all().exclude("spike_times", "stim_centered_spike_times"),
maintain_order=True,
)
.all()
Expand Down Expand Up @@ -104,7 +104,7 @@ def plot(unit_id: str, stim_names=("vis1", "vis2", "sound1", "sound2")) -> plt.F
# make sure there's no info that will trigger plotting:
is_response=pl.lit(False),
is_rewarded=pl.lit(False),
stim_centered_lick_times=pl.lit([]),
stim_centered_spike_times=pl.lit([]),
)
trials_ = pl.concat([trials_, extra_df])

Expand Down Expand Up @@ -266,19 +266,19 @@ def plot(unit_id: str, stim_names=("vis1", "vis2", "sound1", "sound2")) -> plt.F
if trial["is_reward_scheduled"] and trial["trial_index_in_block"] > 10:
ax.axhspan(ypos - halfline, ypos + halfline, **green_patch_params)

# licks
trial_lick_times = np.array(trial["stim_centered_lick_times"])
# spikes
trial_spike_times = np.array(trial["stim_centered_spike_times"])
eventplot_params = dict(
lineoffsets=ypos,
linewidths=0.3,
linelengths=0.8,
color=[0.6] * 3,
zorder=99,
)
if trial_lick_times.size == 1 and trial_lick_times[0] is None:
if trial_spike_times.size == 1 and trial_spike_times[0] is None:
pass
else:
ax.eventplot(positions=trial_lick_times, **eventplot_params)
ax.eventplot(positions=trial_spike_times, **eventplot_params)

# times of interest
override_params = dict(alpha=1)
Expand Down Expand Up @@ -350,11 +350,11 @@ def plot_(hist, bin_edges, **plot_kwargs):
pl.col("stim_name") == stim,
)
)
a = df["stim_centered_lick_times"].to_numpy()
a = df["stim_centered_spike_times"].to_numpy()
if not a.size:
continue
hist, bin_edges = spikes.makePSTH_numba(
spikes=np.sort(subject_lick_times),
spikes=np.sort(unit_spike_times),
startTimes=np.array(df["stim_start_time"] - pad_start),
windowDur=pad_start + xlim_1, binSize=bin_size_s,
)
Expand All @@ -371,7 +371,7 @@ def plot_(hist, bin_edges, **plot_kwargs):
)
)
hist, bin_edges = spikes.makePSTH_numba(
spikes=np.sort(subject_lick_times),
spikes=np.sort(unit_spike_times),
startTimes=np.array(df["stim_start_time"] - pad_start),
windowDur=pad_start + xlim_1, binSize=bin_size_s,
)
Expand Down Expand Up @@ -491,14 +491,49 @@ def get_unit_ids_shawn_session_list() -> pl.Series:
print(f"Expected {top_k * len(sessions)} units, got {len(units)}")
return units

def get_specific_unit_ids() -> list[str]:
    """Return a hand-picked list of unit IDs.

    A fresh list is built on every call, so callers may mutate the result
    without affecting later calls.
    """
    return [
        '667252_2023-09-26_C-233',
    ]

def get_grouped_baseline_psth_parquet(parquet_path=None, k: int = 10):
    """Return the top-``k`` units per structure ranked by absolute context index.

    The baseline-PSTH parquet file is read and joined with the "units"
    component table on ``unit_id``. For each of the ``vis`` and ``aud``
    ``{stim}_target_context_index`` columns, the ``k`` rows with the largest
    absolute value are selected within every ``structure`` group. The
    per-stimulus results are concatenated (no de-duplication is performed, so
    a unit selected for both stimuli appears twice) and sorted by
    ``abs_context_index`` in descending order.

    Args:
        parquet_path: Location of ``baseline_psth.parquet``. When ``None``
            (the default) the original hard-coded location is used, so
            existing zero-argument callers behave as before.
        k: Number of units to keep per structure and per stimulus.

    Returns:
        A frame with the columns ``structure``, ``abs_context_index`` and
        ``unit_id``.
    """
    if parquet_path is None:
        # Machine-specific default kept for backward compatibility.
        parquet_path = r"C:\Users\ben.hardcastle\github\npc_sessions_cache\src\npc_sessions_cache\figures\paper2\baseline_psth.parquet"
    u = (
        pl.read_parquet(parquet_path)
        .join(utils.get_component_df("units"), on='unit_id', )
    )
    dfs = []
    for stim in ("vis", "aud"):
        for target in ("target",):
            col = pl.col(f"{stim}_{target}_context_index").abs()
            dfs.append(
                u
                .group_by(
                    'structure'
                )
                .agg(
                    [
                        # Both lists are selected with the same `top_k_by`
                        # key so that, after the explode below, each value
                        # sits next to the unit it belongs to. The output
                        # order of `top_k` / `top_k_by` is documented as
                        # unspecified, so pairing a `top_k` list with a
                        # `top_k_by` list positionally is not reliable.
                        col.top_k_by(col, k).alias('abs_context_index'),
                        pl.col('unit_id').top_k_by(col, k),
                    ]
                )
                .explode(pl.all().exclude('structure'))
            )

    return pl.concat(dfs).sort('abs_context_index', descending=True)

def get_unit_ids_baseline_psth_parquet():
    """Return the ``unit_id`` column of ``get_grouped_baseline_psth_parquet()``.

    The order of the IDs is whatever order that function returns its rows in.
    """
    return get_grouped_baseline_psth_parquet()['unit_id']

if __name__ == "__main__":

stim_names = ("sound1", "vis1", "sound2", "vis2")
target_stim_names = ("sound1", "vis1")
pyfile_path = pathlib.Path(__file__)
raise_on_error = False
get_unit_id_func = get_unit_ids_shawn_session_list
for unit_id in sorted(get_unit_id_func()):
get_unit_id_func = get_unit_ids_baseline_psth_parquet
for unit_id in get_unit_id_func():
print(f"plotting {pyfile_path.stem} for {unit_id}")
try:
fig = plot(unit_id, stim_names)
Expand All @@ -508,9 +543,16 @@ def get_unit_ids_shawn_session_list() -> pl.Series:
print(f"failed: {exc!r}")
continue
figsave_path = pyfile_path.with_name(f"{pyfile_path.stem}_{unit_id}")
if get_unit_id_func is get_unit_ids_shawn_session_list:
if get_unit_id_func in (
get_unit_ids_shawn_session_list,
get_unit_ids_baseline_psth_parquet
):
materials_path = pathlib.Path("C:/Users/ben.hardcastle/OneDrive - Allen Institute/Shared Documents - Dynamic Routing/DR Manuscripts/DR Paper 2 - context representations/Figures/Figure 3/materials")
requested_path = materials_path / "top_context_units_by_session"
if get_unit_id_func is get_unit_ids_shawn_session_list:
requested_path = materials_path / "top_context_units_by_session"
elif get_unit_id_func is get_unit_ids_baseline_psth_parquet:
area = get_grouped_baseline_psth_parquet().filter(pl.col('unit_id') == unit_id)['structure'][0]
requested_path = materials_path / "top_context_units_by_area" / area
requested_path.mkdir(exist_ok=True, parents=True)
figsave_path = requested_path / figsave_path.name
fig.savefig(f"{figsave_path}.png", dpi=300, bbox_inches="tight")
Expand Down

0 comments on commit 3dcc457

Please sign in to comment.