diff --git a/sleap_io/model/skeleton.py b/sleap_io/model/skeleton.py index 17194b5a..8b2cc728 100644 --- a/sleap_io/model/skeleton.py +++ b/sleap_io/model/skeleton.py @@ -11,7 +11,7 @@ import numpy as np -@define(frozen=True, cache_hash=True) +@define(eq=False) class Node: """A landmark type within a `Skeleton`. @@ -171,8 +171,41 @@ def edge_names(self) -> list[str, str]: return [(edge.source.name, edge.destination.name) for edge in self.edges] @property - def flipped_node_inds(self) -> list[int]: - """Returns node indices that should be switched when horizontally flipping.""" + def symmetry_inds(self) -> list[Tuple[int, int]]: + """Symmetry indices as a list of 2-tuples.""" + return [ + tuple(sorted((self.index(symmetry[0]), self.index(symmetry[1])))) + for symmetry in self.symmetries + ] + + @property + def symmetry_names(self) -> list[str, str]: + """Symmetry names as a list of 2-tuples with string node names.""" + return [ + (self.nodes[i].name, self.nodes[j].name) for (i, j) in self.symmetry_inds + ] + + def get_flipped_node_inds(self) -> list[int]: + """Returns node indices that should be switched when horizontally flipping. + + This is useful as a lookup table for flipping the landmark coordinates when + doing data augmentation. + + Example: + >>> skel = Skeleton(["A", "B_left", "B_right", "C", "D_left", "D_right"]) + >>> skel.add_symmetry("B_left", "B_right") + >>> skel.add_symmetry("D_left", "D_right") + >>> skel.flipped_node_inds + [0, 2, 1, 3, 5, 4] + >>> pose = np.array([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5]]) + >>> pose[skel.flipped_node_inds] + array([[0, 0], + [2, 2], + [1, 1], + [3, 3], + [5, 5], + [4, 4]]) + """ flip_idx = np.arange(len(self.nodes)) if len(self.symmetries) > 0: symmetry_inds = np.array( @@ -217,11 +250,14 @@ def add_node(self, node: Node | str): Args: node: A `Node` object or a string name to create a new node. """ + node_name = node.name if type(node) == Node else node + if node_name in self._node_name_map: + raise ValueError(f"Node '{node_name}' already exists in the skeleton.") if type(node) == str: node = Node(node) if node not in self.nodes: self.nodes.append(node) - self._update_node_map(None, self.nodes) + self._update_node_map(None, self.nodes) def add_edge(self, src: Edge | Node | str = None, dst: Node | str = None): """Add an `Edge` to the skeleton. @@ -268,9 +304,10 @@ def add_symmetry( node2: The second `Node` or name of the second node. """ if type(node1) == Symmetry: - if node1 not in self.symmetries: - self.symmetries.append(node1) - for node in node1.nodes: + symmetry = node1 + if symmetry not in self.symmetries: + self.symmetries.append(symmetry) + for node in symmetry.nodes: if node not in self.nodes: self.add_node(node) return diff --git a/tests/io/test_slp.py b/tests/io/test_slp.py index e837e5df..ab65fba8 100644 --- a/tests/io/test_slp.py +++ b/tests/io/test_slp.py @@ -92,8 +92,8 @@ def test_read_skeleton(centered_pair): assert len(skeleton.nodes) == 24 assert len(skeleton.edges) == 23 assert len(skeleton.symmetries) == 20 - assert Node("wingR") in skeleton.symmetries[0].nodes - assert Node("wingL") in skeleton.symmetries[0].nodes + assert "wingR" in skeleton.symmetry_names[0] + assert "wingL" in skeleton.symmetry_names[0] def test_read_videos_pkg(slp_minimal_pkg): @@ -157,7 +157,10 @@ def test_write_metadata(centered_pair, tmp_path): assert saved_skeletons[0].name == labels.skeletons[0].name assert saved_skeletons[0].node_names == labels.skeletons[0].node_names assert saved_skeletons[0].edge_inds == labels.skeletons[0].edge_inds - assert saved_skeletons[0].flipped_node_inds == labels.skeletons[0].flipped_node_inds + assert ( + saved_skeletons[0].get_flipped_node_inds() + == labels.skeletons[0].get_flipped_node_inds() + ) def test_write_lfs(centered_pair, slp_real_data, tmp_path): @@ -224,8 +227,8 @@ def test_load_multi_skeleton(tmpdir): assert loaded_skels[1].node_names == ["n3", "n4"] assert loaded_skels[0].edge_inds == [(0, 1)] assert loaded_skels[1].edge_inds == [(0, 1)] - assert loaded_skels[0].flipped_node_inds == [1, 0] - assert loaded_skels[1].flipped_node_inds == [1, 0] + assert loaded_skels[0].get_flipped_node_inds() == [1, 0] + assert loaded_skels[1].get_flipped_node_inds() == [1, 0] def test_slp_imgvideo(tmpdir, slp_imgvideo): diff --git a/tests/model/test_labels.py b/tests/model/test_labels.py index fa080bd1..2bc975e7 100644 --- a/tests/model/test_labels.py +++ b/tests/model/test_labels.py @@ -20,14 +20,15 @@ def test_labels(): """Test methods in the `Labels` data structure.""" + skel = Skeleton(["A", "B"]) labels = Labels( [ LabeledFrame( video=Video(filename="test"), frame_idx=0, instances=[ - Instance([[0, 1], [2, 3]], skeleton=Skeleton(["A", "B"])), - PredictedInstance([[4, 5], [6, 7]], skeleton=Skeleton(["A", "B"])), + Instance([[0, 1], [2, 3]], skeleton=skel), + PredictedInstance([[4, 5], [6, 7]], skeleton=skel), ], ) ] diff --git a/tests/model/test_skeleton.py b/tests/model/test_skeleton.py index 0d861dd0..10bb37e2 100644 --- a/tests/model/test_skeleton.py +++ b/tests/model/test_skeleton.py @@ -75,18 +75,16 @@ def test_skeleton_node_map(): assert skel.index("B") == 0 -def test_flipped_node_inds(): +def test_get_flipped_node_inds(): skel = Skeleton(["A", "BL", "BR", "C", "DL", "DR"]) - assert skel.flipped_node_inds == [0, 1, 2, 3, 4, 5] + assert skel.get_flipped_node_inds() == [0, 1, 2, 3, 4, 5] - skel.symmetries = [ - Symmetry([Node("BL"), Node("BR")]), - Symmetry([Node("DL"), Node("DR")]), - ] - assert skel.flipped_node_inds == [0, 2, 1, 3, 5, 4] + skel.add_symmetry("BL", "BR") + skel.add_symmetry("DL", "DR") + assert skel.get_flipped_node_inds() == [0, 2, 1, 3, 5, 4] - assert skel.symmetries[0][0] in (Node("BL"), Node("BR")) - assert skel.symmetries[0][1] in (Node("BL"), Node("BR")) + assert skel.symmetries[0][0].name in ("BL", "BR") + assert skel.symmetries[0][1].name in ("BL", "BR") syms = list(skel.symmetries[0]) assert syms[0] != syms[1] @@ -113,8 +111,9 @@ def test_add_node(): skel.add_node("C") assert skel.node_names == ["A", "B", "C"] - skel.add_node("B") - assert skel.node_names == ["A", "B", "C"] + with pytest.raises(ValueError): + skel.add_node("B") + assert skel.node_names == ["A", "B", "C"] def test_add_edge(): @@ -135,31 +134,22 @@ def test_add_edge(): skel.add_edge("D", "A") assert skel.edge_inds == [(0, 1), (1, 0), (0, 2), (3, 0)] - skel = Skeleton(["A", "B"]) - skel.add_edge(Edge(Node("A"), Node("B"))) - assert skel.edge_inds == [(0, 1)] - - skel.add_edge(Edge(Node("C"), Node("D"))) - assert skel.edge_inds == [(0, 1), (2, 3)] - def test_add_symmetry(): skel = Skeleton(["A", "B"]) skel.add_symmetry("A", "B") - assert skel.symmetries == [Symmetry([Node("A"), Node("B")])] + assert skel.symmetry_inds == [(0, 1)] + assert skel.symmetry_names == [("A", "B")] + # Don't duplicate reversed symmetries skel.add_symmetry("B", "A") - assert skel.symmetries == [Symmetry([Node("A"), Node("B")])] + assert skel.symmetry_inds == [(0, 1)] + assert skel.symmetry_names == [("A", "B")] + # Add new symmetry with new node objects skel.add_symmetry(Symmetry([Node("C"), Node("D")])) - assert skel.symmetries == [ - Symmetry([Node("A"), Node("B")]), - Symmetry([Node("C"), Node("D")]), - ] + assert skel.symmetry_inds == [(0, 1), (2, 3)] + # Add new symmetry with node names skel.add_symmetry("E", "F") - assert skel.symmetries == [ - Symmetry([Node("A"), Node("B")]), - Symmetry([Node("C"), Node("D")]), - Symmetry([Node("E"), Node("F")]), - ] + assert skel.symmetry_inds == [(0, 1), (2, 3), (4, 5)]