Skip to content

Commit

Permalink
Add skeleton symmetry QOL enhancements
Browse files Browse the repository at this point in the history
  • Loading branch information
talmo committed Nov 8, 2024
1 parent 53c1062 commit 483113f
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 0 deletions.
50 changes: 50 additions & 0 deletions sleap_io/model/skeleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def __attrs_post_init__(self):
"""Ensure nodes are `Node`s, edges are `Edge`s, and `Node` map is updated."""
self._convert_nodes()
self._convert_edges()
self._convert_symmetries()
self.rebuild_cache()

def _convert_nodes(self):
Expand Down Expand Up @@ -174,6 +175,44 @@ def _convert_edges(self):

self.edges[i] = Edge(src, dst)

def _convert_symmetries(self):
"""Convert list of symmetric node names or integers to `Symmetry` objects."""
if isinstance(self.symmetries, np.ndarray):
self.symmetries = self.symmetries.tolist()

Check warning on line 181 in sleap_io/model/skeleton.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/skeleton.py#L181

Added line #L181 was not covered by tests

node_names = self.node_names
for i, symmetry in enumerate(self.symmetries):
if type(symmetry) == Symmetry:
continue
node1, node2 = symmetry
if type(node1) == str:
try:
node1 = node_names.index(node1)
except ValueError:
raise ValueError(
f"Node '{node1}' specified in the symmetry list is not in the "
"nodes."
)
if type(node1) == int or (
np.isscalar(node1) and np.issubdtype(node1.dtype, np.integer)
):
node1 = self.nodes[node1]

if type(node2) == str:
try:
node2 = node_names.index(node2)
except ValueError:
raise ValueError(
f"Node '{node2}' specified in the symmetry list is not in the "
"nodes."
)
if type(node2) == int or (
np.isscalar(node2) and np.issubdtype(node2.dtype, np.integer)
):
node2 = self.nodes[node2]

self.symmetries[i] = Symmetry({node1, node2})

def rebuild_cache(self, nodes: list[Node] | None = None):
"""Rebuild the node name/index to `Node` map caches.
Expand Down Expand Up @@ -425,6 +464,17 @@ def add_symmetry(
if symmetry not in self.symmetries:
self.symmetries.append(symmetry)

def add_symmetries(
self, symmetries: list[Symmetry | tuple[NodeOrIndex, NodeOrIndex]]
):
"""Add multiple `Symmetry` relationships to the skeleton.
Args:
symmetries: A list of `Symmetry` objects or 2-tuples of symmetric nodes.
"""
for symmetry in symmetries:
self.add_symmetry(*symmetry)

def rename_nodes(self, name_map: dict[NodeOrIndex, str] | list[str]):
"""Rename nodes in the skeleton.
Expand Down
14 changes: 14 additions & 0 deletions tests/model/test_skeleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,15 @@ def test_skeleton():
with pytest.raises(ValueError):
Skeleton(["A", "B"], edges=[("A", "C")])

skel = Skeleton(["A", "B"], symmetries=[("A", "B")])
assert skel.symmetry_inds == [(0, 1)]

with pytest.raises(ValueError):
Skeleton(["A", "B"], symmetries=[("a", "B")])

with pytest.raises(ValueError):
Skeleton(["A", "B"], symmetries=[("A", "b")])


def test_skeleton_node_map():
"""Test `Skeleton` node map returns correct nodes."""
Expand Down Expand Up @@ -165,6 +174,11 @@ def test_add_symmetry():
skel.add_symmetry("E", "F")
assert skel.symmetry_inds == [(0, 1), (2, 3), (4, 5)]

# Add symmetries
skel.add_nodes(["GL", "GR", "HL", "HR"])
skel.add_symmetries([("GL", "GR"), ("HL", "HR")])
assert skel.symmetry_inds == [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9)]


def test_rename_nodes():
"""Test renaming nodes in the skeleton."""
Expand Down

0 comments on commit 483113f

Please sign in to comment.