Skip to content

Commit

Permalink
append_nwb_training update
Browse files Browse the repository at this point in the history
  • Loading branch information
keyaloding committed Jul 19, 2024
1 parent dcc3da8 commit 0ba071e
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
1 change: 1 addition & 0 deletions sleap_io/io/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def save_nwb(
as_training: If `True`, save the dataset as a training dataset.
append: If `True` (the default), append to existing NWB file. File will be
created if it does not exist.
img_paths: Optional list of image paths to save to the NWB file.
See also: nwb.write_nwb, nwb.append_nwb
"""
Expand Down
17 changes: 16 additions & 1 deletion sleap_io/io/nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ def labels_to_pose_training(
if not isinstance(labels, Labels):
raise ValueError("The input must be a SLEAP Labels object.")

skeletons_list: list[Skeleton] = [] # type: ignore[assignment]
skeletons = Skeletons(skeletons=skeletons_list) # type: ignore[assignment]
training_frame_list = []
for i, labeled_frame in enumerate(labels.labeled_frames):
training_frame_name = name_generator("training_frame")
Expand All @@ -123,6 +125,10 @@ def labels_to_pose_training(
for instance in labeled_frame.instances:
if isinstance(instance, PredictedInstance):
continue

if instance.skeleton not in skeletons_list:
skeletons_list.append(instance.skeleton)
skeletons = Skeletons(skeletons=skeletons_list)
skeleton_instance = instance_to_skeleton_instance(instance)
skeleton_instances_list.append(skeleton_instance)

Expand Down Expand Up @@ -568,7 +574,16 @@ def append_nwb_training(
default_metadata = dict(scorer=str(provenance))
sleap_version = provenance.get("sleap_version", None)
default_metadata["source_software_version"] = sleap_version
default_metadata.update(pose_estimation_metadata)

for i, video in enumerate(labels.videos):
video_path = Path(video.filename)
processing_module_name = f"SLEAP_VIDEO_{i:03}_{video_path.stem}"
nwb_processing_module = get_processing_module_for_video(
processing_module_name, nwbfile
)
default_metadata["original_videos"] = [f"{video.filename}"]
default_metadata["labeled_videos"] = [f"{video.filename}"]
default_metadata.update(pose_estimation_metadata)

subject = Subject(subject_id="No specified id", species="No specified species")
nwbfile.subject = subject
Expand Down

0 comments on commit 0ba071e

Please sign in to comment.