Skip to content

Commit

Permalink
updated tests
Browse files Browse the repository at this point in the history
  • Loading branch information
keyaloding committed Aug 11, 2024
1 parent 52d45b8 commit bab0917
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 7 deletions.
6 changes: 1 addition & 5 deletions sleap_io/io/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,7 @@ def load_nwb(filename: str) -> Labels:
Returns:
The dataset as a `Labels` object.
"""
with NWBHDF5IO(filename, "r", load_namespaces=True) as io:
nwb_processing = io.read().processing
if any("PoseTraining" in module for module in nwb_processing.values()):
return nwb.read_nwb_training(nwb_processing)
return nwb.read_nwb(filename)
return nwb.read_nwb(filename)


def save_nwb(
Expand Down
2 changes: 1 addition & 1 deletion sleap_io/io/nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,7 +823,7 @@ def build_pose_estimation_container_for_track(
for i, video in enumerate(labels.videos):
camera = nwbfile.create_device(
name=f"camera {i}",
description="Camera used to record video {i}",
description=f"Camera used to record video {i}",
manufacturer="No specified manufacturer",
)
cameras.append(camera)
Expand Down
Binary file added tests/data/nwb/minimal_instance.nwb
Binary file not shown.
16 changes: 15 additions & 1 deletion tests/io/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,24 @@ def test_load_slp(slp_typical):

def test_nwb(tmp_path, slp_typical):
labels = load_slp(slp_typical)
save_nwb(labels, tmp_path / "test.nwb")
save_nwb(labels, tmp_path / "test_nwb.nwb", False)
loaded_labels = load_nwb(tmp_path / "test_nwb.nwb")
assert type(loaded_labels) == Labels
assert type(load_file(tmp_path / "test_nwb.nwb")) == Labels
assert len(loaded_labels) == len(labels)

labels2 = load_slp(slp_typical)
labels2.videos[0].filename = "test"
save_nwb(labels2, tmp_path / "test_nwb.nwb", append=True)
loaded_labels = load_nwb(tmp_path / "test_nwb.nwb")
assert type(loaded_labels) == Labels
assert len(loaded_labels) == (len(labels) + len(labels2))
assert len(loaded_labels.videos) == 2


def test_nwb_training(tmp_path, slp_typical):
labels = load_slp(slp_typical)
save_nwb(labels, tmp_path / "test_nwb.nwb", True)


def test_labelstudio(tmp_path, slp_typical):
Expand Down

0 comments on commit bab0917

Please sign in to comment.