Skip to content

Commit

Permalink
Fix multi-skeleton loading (#79)
Browse files Browse the repository at this point in the history
* Add multi skeleton loading test

* Fix
  • Loading branch information
talmo authored Apr 14, 2024
1 parent 304a9d8 commit 12b15fa
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 2 deletions.
4 changes: 2 additions & 2 deletions sleap_io/io/slp.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,11 +207,11 @@ def read_skeletons(labels_path: str) -> list[Skeleton]:

# Re-index correctly.
skeleton_node_inds = [node["id"] for node in skel["nodes"]]
node_names = [node_names[i] for i in skeleton_node_inds]
sorted_node_names = [node_names[i] for i in skeleton_node_inds]

# Create nodes.
nodes = []
for name in node_names:
for name in sorted_node_names:
nodes.append(Node(name=name))

# Create edges.
Expand Down
28 changes: 28 additions & 0 deletions tests/io/test_slp.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,31 @@ def test_write_labels(centered_pair, slp_real_data, tmp_path):
assert len(saved_labels.skeletons) == len(labels.skeletons) == 1
assert saved_labels.skeleton.name == labels.skeleton.name
assert saved_labels.skeleton.node_names == labels.skeleton.node_names


def test_load_multi_skeleton(tmpdir):
"""Test loading multiple skeletons from a single file."""
skel1 = Skeleton()
skel1.add_node(Node("n1"))
skel1.add_node(Node("n2"))
skel1.add_edge("n1", "n2")
skel1.add_symmetry("n1", "n2")

skel2 = Skeleton()
skel2.add_node(Node("n3"))
skel2.add_node(Node("n4"))
skel2.add_edge("n3", "n4")
skel2.add_symmetry("n3", "n4")

skels = [skel1, skel2]
labels = Labels(skeletons=skels)
write_metadata(tmpdir / "test.slp", labels)

loaded_skels = read_skeletons(tmpdir / "test.slp")
assert len(loaded_skels) == 2
assert loaded_skels[0].node_names == ["n1", "n2"]
assert loaded_skels[1].node_names == ["n3", "n4"]
assert loaded_skels[0].edge_inds == [(0, 1)]
assert loaded_skels[1].edge_inds == [(0, 1)]
assert loaded_skels[0].flipped_node_inds == [1, 0]
assert loaded_skels[1].flipped_node_inds == [1, 0]

0 comments on commit 12b15fa

Please sign in to comment.