diff --git a/sleap_io/model/skeleton.py b/sleap_io/model/skeleton.py index dbc190d8..458cca42 100644 --- a/sleap_io/model/skeleton.py +++ b/sleap_io/model/skeleton.py @@ -183,7 +183,7 @@ def __len__(self) -> int: """Return the number of nodes in the skeleton.""" return len(self.nodes) - def index(self, node: Union[Node, str]) -> int: + def index(self, node: Node | str) -> int: """Return the index of a node specified as a `Node` or string name.""" if type(node) == str: return self.index(self._node_name_map[node]) @@ -192,7 +192,7 @@ def index(self, node: Union[Node, str]) -> int: else: raise IndexError(f"Invalid indexing argument for skeleton: {node}") - def __getitem__(self, idx: Union[int, str]) -> Node: + def __getitem__(self, idx: int | str) -> Node: """Return a `Node` when indexing by name or integer.""" if type(idx) == int: return self.nodes[idx] @@ -200,3 +200,85 @@ def __getitem__(self, idx: Union[int, str]) -> Node: return self._node_name_map[idx] else: raise IndexError(f"Invalid indexing argument for skeleton: {idx}") + + def add_node(self, node: Node | str): + """Add a `Node` to the skeleton. + + Args: + node: A `Node` object or a string name to create a new node. + """ + if type(node) == str: + node = Node(node) + if node not in self.nodes: + self.nodes.append(node) + self._update_node_map(None, self.nodes) + + def add_edge(self, src: Edge | Node | str = None, dst: Node | str = None): + """Add an `Edge` to the skeleton. + + Args: + src: The source `Node` or name of the source node. + dst: The destination `Node` or name of the destination node. + """ + if type(src) == Edge: + edge = src + if edge not in self.edges: + self.edges.append(edge) + if edge.source not in self.nodes: + self.add_node(edge.source) + if edge.destination not in self.nodes: + self.add_node(edge.destination) + return + + if type(src) == str or type(src) == Node: + try: + src = self.index(src) + except KeyError: + self.add_node(src) + src = self.index(src) + + if type(dst) == str or type(dst) == Node: + try: + dst = self.index(dst) + except KeyError: + self.add_node(dst) + dst = self.index(dst) + + edge = Edge(self.nodes[src], self.nodes[dst]) + if edge not in self.edges: + self.edges.append(edge) + + def add_symmetry( + self, node1: Symmetry | Node | str = None, node2: Node | str = None + ): + """Add a symmetry relationship to the skeleton. + + Args: + node1: The first `Node` or name of the first node. + node2: The second `Node` or name of the second node. + """ + if type(node1) == Symmetry: + if node1 not in self.symmetries: + self.symmetries.append(node1) + for node in node1.nodes: + if node not in self.nodes: + self.add_node(node) + return + + if type(node1) == str or type(node1) == Node: + try: + node1 = self.index(node1) + except KeyError: + self.add_node(node1) + node1 = self.index(node1) + + if type(node2) == str or type(node2) == Node: + try: + node2 = self.index(node2) + except KeyError: + self.add_node(node2) + node2 = self.index(node2) + + symmetry = Symmetry({self.nodes[node1], self.nodes[node2]}) + if symmetry not in self.symmetries: + self.symmetries.append(symmetry) diff --git a/tests/model/test_skeleton.py b/tests/model/test_skeleton.py index 27a61c2c..49d44d96 100644 --- a/tests/model/test_skeleton.py +++ b/tests/model/test_skeleton.py @@ -98,3 +98,65 @@ def test_edge_unpack(): src, dst = skel.edges[0] assert src.name == "A" assert dst.name == "B" + + +def test_add_node(): + skel = Skeleton() + skel.add_node("A") + assert skel.node_names == ["A"] + + skel.add_node(Node("B")) + assert skel.node_names == ["A", "B"] + + skel.add_node("C") + assert skel.node_names == ["A", "B", "C"] + + skel.add_node("B") + assert skel.node_names == ["A", "B", "C"] + + +def test_add_edge(): + skel = Skeleton(["A", "B"]) + skel.add_edge("A", "B") + assert skel.edge_inds == [(0, 1)] + + skel.add_edge("B", "A") + assert skel.edge_inds == [(0, 1), (1, 0)] + + skel.add_edge("A", "B") + assert skel.edge_inds == [(0, 1), (1, 0)] + + skel.add_edge("A", "C") + assert skel.edge_inds == [(0, 1), (1, 0), (0, 2)] + + skel.add_edge("D", "A") + assert skel.edge_inds == [(0, 1), (1, 0), (0, 2), (3, 0)] + + skel = Skeleton(["A", "B"]) + skel.add_edge(Edge(Node("A"), Node("B"))) + assert skel.edge_inds == [(0, 1)] + + skel.add_edge(Edge(Node("C"), Node("D"))) + assert skel.edge_inds == [(0, 1), (2, 3)] + + +def test_add_symmetry(): + skel = Skeleton(["A", "B"]) + skel.add_symmetry("A", "B") + assert skel.symmetries == [Symmetry([Node("A"), Node("B")])] + + skel.add_symmetry("B", "A") + assert skel.symmetries == [Symmetry([Node("A"), Node("B")])] + + skel.add_symmetry(Symmetry([Node("C"), Node("D")])) + assert skel.symmetries == [ + Symmetry([Node("A"), Node("B")]), + Symmetry([Node("C"), Node("D")]), + ] + + skel.add_symmetry("E", "F") + assert skel.symmetries == [ + Symmetry([Node("A"), Node("B")]), + Symmetry([Node("C"), Node("D")]), + Symmetry([Node("E"), Node("F")]), + ]