Skip to content

Commit

Permalink
m
Browse files Browse the repository at this point in the history
  • Loading branch information
keyaloding committed Jul 18, 2024
1 parent 1b70d5c commit c59f996
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 13 deletions.
4 changes: 2 additions & 2 deletions sleap_io/io/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ def save_nwb(
"""
if as_training:
if append and Path(filename).exists():
nwb.append_nwb_training(labels, filename)
nwb.append_nwb(labels, filename, as_training=True)
else:
nwb.write_nwb(labels, filename, None, None, True)
nwb.write_nwb(labels, filename, as_training=True)

else:
if append and Path(filename).exists():
Expand Down
15 changes: 4 additions & 11 deletions sleap_io/io/nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ def append_nwb_data(


def append_nwb_training(
labels: Labels, nwbfile: NWBFile, pose_estimation_metadata: Optional[dict]
labels: Labels, nwbfile: NWBFile, pose_estimation_metadata: Optional[dict] = None
) -> NWBFile:
"""Append training data from a Labels object to an in-memory NWB file.
Expand Down Expand Up @@ -548,27 +548,20 @@ def append_nwb_training(
)
pose_estimation_series_list.append(pose_estimation_series)

dimensions = np.array([[labels.videos[0].backend.shape[1], labels.videos[0].backend.shape[2]]])
pose_estimation = PoseEstimation(name="pose_estimation",
pose_estimation_series=pose_estimation_series_list,
description="Estimated positions of the nodes in the video",
original_videos=[video.filename for video in labels.videos],
labeled_videos=[video.filename for video in labels.videos],
dimensions=np.array([[labels.videos[0].backend.height, labels.videos[0].backend.width]]),
dimensions=dimensions,
devices=[camera],
scorer="No specified scorer",
source_software="SLEAP",
source_software_version=sleap_version,
skeleton=skeletons_list[0],
)



for lf in labels.labeled_frames:
if lf.has_predicted_instances:
labels_data_df = convert_predictions_to_dataframe(labels)
break
else:
labels_data_df = pd.DataFrame()
behavior_pm.add(pose_estimation)
return nwbfile


Expand Down

0 comments on commit c59f996

Please sign in to comment.