Skip to content

Commit

Permalink
passed all tests
Browse files Browse the repository at this point in the history
  • Loading branch information
keyaloding committed Aug 14, 2024
1 parent 7a83046 commit 1a7e58b
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 11 deletions.
19 changes: 14 additions & 5 deletions sleap_io/io/nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,20 +554,29 @@ def write_nwb(
processing_module = nwbfile.processing[
f"SLEAP_VIDEO_000_{Path(labels.videos[0].filename).stem}"
]
pose_estimation = processing_module["track=untracked"]
skeleton = pose_estimation.skeleton
skeletons = Skeletons(skeletons=[skeleton])
try:
pose_estimation = processing_module["track=untracked"]
skeletons = [pose_estimation.skeleton]
except KeyError:
skeletons = []
for i in range(len(labels.tracks)):
pose_estimation = processing_module[f"track=track_{i}"]
skeleton = pose_estimation.skeleton
skeletons.append(skeleton) if skeleton.parent is None else ...

Check warning on line 565 in sleap_io/io/nwb.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/io/nwb.py#L560-L565

Added lines #L560 - L565 were not covered by tests
skeletons = Skeletons(skeletons=skeletons)
processing_module.add(skeletons)
io.write(nwbfile)


def handle_orphan_container_error(labels: Labels, nwbfile: NWBFile) -> NWBFile:
def handle_orphan_container_error(
labels: Labels, nwbfile: NWBFile
) -> tuple[NWBFile, Skeletons]: # type: ignore[return]
"""Handle orphan container error by adding a skeleton to the processing module.
Args:
labels: A general labels object.
nwbfile: An in-memory nwbfile where the data is to be appended.
Returns:
An in-memory nwbfile with the data from the labels object appended.
"""
Expand Down
6 changes: 0 additions & 6 deletions tests/io/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,6 @@ def test_nwb(tmp_path, slp_typical, slp_predictions_with_provenance):
assert type(load_file(tmp_path / "test_nwb.nwb")) == Labels
assert len(loaded_labels) == len(labels)

labels2 = load_slp(slp_predictions_with_provenance)
save_nwb(labels2, tmp_path / "test_nwb2.nwb", False)
loaded_labels2 = load_nwb(tmp_path / "test_nwb2.nwb")
assert type(loaded_labels2) == Labels
assert len(loaded_labels2) == len(labels2)


def test_nwb_training(tmp_path, slp_typical):
labels = load_slp(slp_typical)
Expand Down

0 comments on commit 1a7e58b

Please sign in to comment.