Skip to content

Commit

Permalink
Rm lazy dataframe newscale coords - fixes #118
Browse files Browse the repository at this point in the history
  • Loading branch information
bjhardcastle committed Aug 15, 2024
1 parent 8f7d711 commit dbc1a7e
Showing 1 changed file with 1 addition and 24 deletions.
25 changes: 1 addition & 24 deletions src/npc_sessions/utils/newscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,18 +80,6 @@ def get_newscale_data(path: npc_io.PathLike) -> pl.DataFrame:
)


def get_newscale_data_lazy(path: npc_io.PathLike) -> pl.LazyFrame:
"""
# >>> df = get_newscale_data_lazy('s3://aind-ephys-data/ecephys_686740_2023-10-23_14-11-05/behavior/log.csv')
"""
# TODO not working with s3 paths
return pl.scan_csv(
source=npc_io.from_pathlike(path).as_posix(),
with_column_names=lambda _: list(NEWSCALE_LOG_COLUMNS),
try_parse_dates=True,
)


def get_newscale_coordinates(
newscale_log_path: npc_io.PathLike,
recording_start_time: (
Expand Down Expand Up @@ -123,11 +111,7 @@ def get_newscale_coordinates(

movement = pl.col(NEWSCALE_LOG_COLUMNS[0])
serial_number = pl.col(NEWSCALE_LOG_COLUMNS[1])
df: pl.DataFrame
try:
df = get_newscale_data_lazy(newscale_log_path) # type: ignore [assignment]
except Exception:
df = get_newscale_data(newscale_log_path)
df = get_newscale_data(newscale_log_path)

# if experiment date isn't in df, the log file didn't cover this experiment -
# we can't continue
Expand All @@ -140,14 +124,9 @@ def get_newscale_coordinates(
pl.col("last_movement_dt").dt.date()
> (start.dt.date() - datetime.timedelta(hours=24))
)
if isinstance(df, pl.LazyFrame):
recent_df = recent_df.collect()
recent_z_values = recent_df["z"].str.strip_chars().cast(pl.Float32).to_numpy()
z_inverted: bool = is_z_inverted(recent_z_values)

if isinstance(df, pl.LazyFrame):
df = df.collect()

df = (
df.filter(movement < start.dt)
.group_by(serial_number)
Expand All @@ -156,8 +135,6 @@ def get_newscale_coordinates(
) # get last-moved for each manipulator
.top_k(6, by=movement)
)
if isinstance(df, pl.LazyFrame):
df = df.collect()

# serial numbers have an extra leading space
manipulators = df.get_column(NEWSCALE_LOG_COLUMNS[1]).str.strip_chars()
Expand Down

0 comments on commit dbc1a7e

Please sign in to comment.