From 6850eb0ddf0725a04ad733e40b52c26d48203add Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Sat, 13 Apr 2024 18:43:42 -0700 Subject: [PATCH 1/3] Add skeleton utilities --- sleap_io/model/skeleton.py | 85 +++++++++++++++++++++++++++++++++++- tests/model/test_skeleton.py | 48 ++++++++++++++++++++ 2 files changed, 131 insertions(+), 2 deletions(-) diff --git a/sleap_io/model/skeleton.py b/sleap_io/model/skeleton.py index dbc190d8..b227f828 100644 --- a/sleap_io/model/skeleton.py +++ b/sleap_io/model/skeleton.py @@ -183,7 +183,7 @@ def __len__(self) -> int: """Return the number of nodes in the skeleton.""" return len(self.nodes) - def index(self, node: Union[Node, str]) -> int: + def index(self, node: Node | str) -> int: """Return the index of a node specified as a `Node` or string name.""" if type(node) == str: return self.index(self._node_name_map[node]) @@ -192,7 +192,7 @@ def index(self, node: Union[Node, str]) -> int: else: raise IndexError(f"Invalid indexing argument for skeleton: {node}") - def __getitem__(self, idx: Union[int, str]) -> Node: + def __getitem__(self, idx: int | str) -> Node: """Return a `Node` when indexing by name or integer.""" if type(idx) == int: return self.nodes[idx] @@ -200,3 +200,84 @@ def __getitem__(self, idx: Union[int, str]) -> Node: return self._node_name_map[idx] else: raise IndexError(f"Invalid indexing argument for skeleton: {idx}") + + def add_node(self, node: Node | str): + """Add a `Node` to the skeleton. + + Args: + node: A `Node` object or a string name to create a new node. + """ + if type(node) == str: + node = Node(node) + if node not in self.nodes: + self.nodes.append(node) + self._update_node_map(None, self.nodes) + + def add_edge(self, src: Edge | Node | str = None, dst: Node | str = None): + """Add an `Edge` to the skeleton. + + Args: + src: The source `Node` or name of the source node. + dst: The destination `Node` or name of the destination node. + """ + if type(src) == Edge: + if src not in self.edges: + self.edges.append(src) + if edge.source not in self.nodes: + self.add_node(edge.source) + if edge.destination not in self.nodes: + self.add_node(edge.destination) + return + + if type(src) == str or type(src) == Node: + try: + src = self.index(src) + except KeyError: + self.add_node(src) + src = self.index(src) + + if type(dst) == str or type(dst) == Node: + try: + dst = self.index(dst) + except KeyError: + self.add_node(dst) + dst = self.index(dst) + + edge = Edge(self.nodes[src], self.nodes[dst]) + if edge not in self.edges: + self.edges.append(edge) + + def add_symmetry( + self, node1: Symmetry | Node | str = None, node2: Node | str = None + ): + """Add a symmetry relationship to the skeleton. + + Args: + node1: The first `Node` or name of the first node. + node2: The second `Node` or name of the second node. + """ + if type(node1) == Symmetry: + if node1 not in self.symmetries: + self.symmetries.append(node1) + for node in node1.nodes: + if node not in self.nodes: + self.add_node(node) + return + + if type(node1) == str or type(node1) == Node: + try: + node1 = self.index(node1) + except KeyError: + self.add_node(node1) + node1 = self.index(node1) + + if type(node2) == str or type(node2) == Node: + try: + node2 = self.index(node2) + except KeyError: + self.add_node(node2) + node2 = self.index(node2) + + symmetry = Symmetry({self.nodes[node1], self.nodes[node2]}) + if symmetry not in self.symmetries: + self.symmetries.append(symmetry) diff --git a/tests/model/test_skeleton.py b/tests/model/test_skeleton.py index 27a61c2c..26c0a677 100644 --- a/tests/model/test_skeleton.py +++ b/tests/model/test_skeleton.py @@ -98,3 +98,51 @@ def test_edge_unpack(): src, dst = skel.edges[0] assert src.name == "A" assert dst.name == "B" + + +def test_add_node(): + skel = Skeleton() + skel.add_node("A") + assert skel.node_names == ["A"] + + skel.add_node(Node("B")) + assert skel.node_names == ["A", "B"] + + skel.add_node("C") + assert skel.node_names == ["A", "B", "C"] + + skel.add_node("B") + assert skel.node_names == ["A", "B", "C"] + + +def test_add_edge(): + skel = Skeleton(["A", "B"]) + skel.add_edge("A", "B") + assert skel.edge_inds == [(0, 1)] + + skel.add_edge("B", "A") + assert skel.edge_inds == [(0, 1), (1, 0)] + + skel.add_edge("A", "B") + assert skel.edge_inds == [(0, 1), (1, 0)] + + skel.add_edge("A", "C") + assert skel.edge_inds == [(0, 1), (1, 0), (0, 2)] + + skel.add_edge("D", "A") + assert skel.edge_inds == [(0, 1), (1, 0), (0, 2), (3, 0)] + + +def test_add_symmetry(): + skel = Skeleton(["A", "B"]) + skel.add_symmetry("A", "B") + assert skel.symmetries == [Symmetry([Node("A"), Node("B")])] + + skel.add_symmetry("B", "A") + assert skel.symmetries == [Symmetry([Node("A"), Node("B")])] + + skel.add_symmetry(Symmetry([Node("C"), Node("D")])) + assert skel.symmetries == [ + Symmetry([Node("A"), Node("B")]), + Symmetry([Node("C"), Node("D")]), + ] From 578f6a44a9dffdceed4e13d756278a3a21ee4988 Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Sat, 13 Apr 2024 18:52:08 -0700 Subject: [PATCH 2/3] Coverage and bugs --- sleap_io/model/skeleton.py | 5 +++-- tests/model/test_skeleton.py | 11 +++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/sleap_io/model/skeleton.py b/sleap_io/model/skeleton.py index b227f828..458cca42 100644 --- a/sleap_io/model/skeleton.py +++ b/sleap_io/model/skeleton.py @@ -221,8 +221,9 @@ def add_edge(self, src: Edge | Node | str = None, dst: Node | str = None): dst: The destination `Node` or name of the destination node. """ if type(src) == Edge: - if src not in self.edges: - self.edges.append(src) + edge = src + if edge not in self.edges: + self.edges.append(edge) if edge.source not in self.nodes: self.add_node(edge.source) if edge.destination not in self.nodes: diff --git a/tests/model/test_skeleton.py b/tests/model/test_skeleton.py index 26c0a677..433609b5 100644 --- a/tests/model/test_skeleton.py +++ b/tests/model/test_skeleton.py @@ -132,6 +132,10 @@ def test_add_edge(): skel.add_edge("D", "A") assert skel.edge_inds == [(0, 1), (1, 0), (0, 2), (3, 0)] + skel = Skeleton(["A", "B"]) + skel.add_edge(Edge(Node("A"), Node("B"))) + assert skel.edge_inds == [(0, 1)] + def test_add_symmetry(): skel = Skeleton(["A", "B"]) @@ -146,3 +150,10 @@ def test_add_symmetry(): Symmetry([Node("A"), Node("B")]), Symmetry([Node("C"), Node("D")]), ] + + skel.add_symmetry("E", "F") + assert skel.symmetries == [ + Symmetry([Node("A"), Node("B")]), + Symmetry([Node("C"), Node("D")]), + Symmetry([Node("E"), Node("F")]), + ] From 6d14979e173173d5ec79807fd2da6e859553e1fa Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Sat, 13 Apr 2024 18:54:24 -0700 Subject: [PATCH 3/3] Cover! --- tests/model/test_skeleton.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/model/test_skeleton.py b/tests/model/test_skeleton.py index 433609b5..49d44d96 100644 --- a/tests/model/test_skeleton.py +++ b/tests/model/test_skeleton.py @@ -136,6 +136,9 @@ def test_add_edge(): skel.add_edge(Edge(Node("A"), Node("B"))) assert skel.edge_inds == [(0, 1)] + skel.add_edge(Edge(Node("C"), Node("D"))) + assert skel.edge_inds == [(0, 1), (2, 3)] + def test_add_symmetry(): skel = Skeleton(["A", "B"])