Skip to content

Commit

Permalink
Add skeleton utilities (#76)
Browse files Browse the repository at this point in the history
* Add skeleton utilities

* Coverage and bugs

* Cover!
  • Loading branch information
talmo authored Apr 14, 2024
1 parent 65b5ac4 commit 5d98c96
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 2 deletions.
86 changes: 84 additions & 2 deletions sleap_io/model/skeleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def __len__(self) -> int:
"""Return the number of nodes in the skeleton."""
return len(self.nodes)

def index(self, node: Union[Node, str]) -> int:
def index(self, node: Node | str) -> int:
"""Return the index of a node specified as a `Node` or string name."""
if type(node) == str:
return self.index(self._node_name_map[node])
Expand All @@ -192,11 +192,93 @@ def index(self, node: Union[Node, str]) -> int:
else:
raise IndexError(f"Invalid indexing argument for skeleton: {node}")

def __getitem__(self, idx: Union[int, str]) -> Node:
def __getitem__(self, idx: int | str) -> Node:
"""Return a `Node` when indexing by name or integer."""
if type(idx) == int:
return self.nodes[idx]
elif type(idx) == str:
return self._node_name_map[idx]
else:
raise IndexError(f"Invalid indexing argument for skeleton: {idx}")

def add_node(self, node: Node | str):
"""Add a `Node` to the skeleton.
Args:
node: A `Node` object or a string name to create a new node.
"""
if type(node) == str:
node = Node(node)
if node not in self.nodes:
self.nodes.append(node)
self._update_node_map(None, self.nodes)

def add_edge(self, src: Edge | Node | str = None, dst: Node | str = None):
"""Add an `Edge` to the skeleton.
Args:
src: The source `Node` or name of the source node.
dst: The destination `Node` or name of the destination node.
"""
if type(src) == Edge:
edge = src
if edge not in self.edges:
self.edges.append(edge)
if edge.source not in self.nodes:
self.add_node(edge.source)
if edge.destination not in self.nodes:
self.add_node(edge.destination)
return

if type(src) == str or type(src) == Node:
try:
src = self.index(src)
except KeyError:
self.add_node(src)
src = self.index(src)

if type(dst) == str or type(dst) == Node:
try:
dst = self.index(dst)
except KeyError:
self.add_node(dst)
dst = self.index(dst)

edge = Edge(self.nodes[src], self.nodes[dst])
if edge not in self.edges:
self.edges.append(edge)

def add_symmetry(
self, node1: Symmetry | Node | str = None, node2: Node | str = None
):
"""Add a symmetry relationship to the skeleton.
Args:
node1: The first `Node` or name of the first node.
node2: The second `Node` or name of the second node.
"""
if type(node1) == Symmetry:
if node1 not in self.symmetries:
self.symmetries.append(node1)
for node in node1.nodes:
if node not in self.nodes:
self.add_node(node)
return

if type(node1) == str or type(node1) == Node:
try:
node1 = self.index(node1)
except KeyError:
self.add_node(node1)
node1 = self.index(node1)

if type(node2) == str or type(node2) == Node:
try:
node2 = self.index(node2)
except KeyError:
self.add_node(node2)
node2 = self.index(node2)

symmetry = Symmetry({self.nodes[node1], self.nodes[node2]})
if symmetry not in self.symmetries:
self.symmetries.append(symmetry)
62 changes: 62 additions & 0 deletions tests/model/test_skeleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,65 @@ def test_edge_unpack():
src, dst = skel.edges[0]
assert src.name == "A"
assert dst.name == "B"


def test_add_node():
skel = Skeleton()
skel.add_node("A")
assert skel.node_names == ["A"]

skel.add_node(Node("B"))
assert skel.node_names == ["A", "B"]

skel.add_node("C")
assert skel.node_names == ["A", "B", "C"]

skel.add_node("B")
assert skel.node_names == ["A", "B", "C"]


def test_add_edge():
skel = Skeleton(["A", "B"])
skel.add_edge("A", "B")
assert skel.edge_inds == [(0, 1)]

skel.add_edge("B", "A")
assert skel.edge_inds == [(0, 1), (1, 0)]

skel.add_edge("A", "B")
assert skel.edge_inds == [(0, 1), (1, 0)]

skel.add_edge("A", "C")
assert skel.edge_inds == [(0, 1), (1, 0), (0, 2)]

skel.add_edge("D", "A")
assert skel.edge_inds == [(0, 1), (1, 0), (0, 2), (3, 0)]

skel = Skeleton(["A", "B"])
skel.add_edge(Edge(Node("A"), Node("B")))
assert skel.edge_inds == [(0, 1)]

skel.add_edge(Edge(Node("C"), Node("D")))
assert skel.edge_inds == [(0, 1), (2, 3)]


def test_add_symmetry():
skel = Skeleton(["A", "B"])
skel.add_symmetry("A", "B")
assert skel.symmetries == [Symmetry([Node("A"), Node("B")])]

skel.add_symmetry("B", "A")
assert skel.symmetries == [Symmetry([Node("A"), Node("B")])]

skel.add_symmetry(Symmetry([Node("C"), Node("D")]))
assert skel.symmetries == [
Symmetry([Node("A"), Node("B")]),
Symmetry([Node("C"), Node("D")]),
]

skel.add_symmetry("E", "F")
assert skel.symmetries == [
Symmetry([Node("A"), Node("B")]),
Symmetry([Node("C"), Node("D")]),
Symmetry([Node("E"), Node("F")]),
]

0 comments on commit 5d98c96

Please sign in to comment.