Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make skeleton nodes mutable #135

Merged
merged 1 commit into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 44 additions & 7 deletions sleap_io/model/skeleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy as np


@define(frozen=True, cache_hash=True)
@define(eq=False)
class Node:
"""A landmark type within a `Skeleton`.

Expand Down Expand Up @@ -171,8 +171,41 @@ def edge_names(self) -> list[str, str]:
return [(edge.source.name, edge.destination.name) for edge in self.edges]

@property
def flipped_node_inds(self) -> list[int]:
"""Returns node indices that should be switched when horizontally flipping."""
def symmetry_inds(self) -> list[Tuple[int, int]]:
"""Symmetry indices as a list of 2-tuples."""
return [
tuple(sorted((self.index(symmetry[0]), self.index(symmetry[1]))))
for symmetry in self.symmetries
]

@property
def symmetry_names(self) -> list[str, str]:
"""Symmetry names as a list of 2-tuples with string node names."""
return [
(self.nodes[i].name, self.nodes[j].name) for (i, j) in self.symmetry_inds
]

def get_flipped_node_inds(self) -> list[int]:
"""Returns node indices that should be switched when horizontally flipping.

This is useful as a lookup table for flipping the landmark coordinates when
doing data augmentation.

Example:
>>> skel = Skeleton(["A", "B_left", "B_right", "C", "D_left", "D_right"])
>>> skel.add_symmetry("B_left", "B_right")
>>> skel.add_symmetry("D_left", "D_right")
>>> skel.flipped_node_inds
[0, 2, 1, 3, 5, 4]
>>> pose = np.array([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5]])
>>> pose[skel.flipped_node_inds]
array([[0, 0],
[2, 2],
[1, 1],
[3, 3],
[5, 5],
[4, 4]])
"""
flip_idx = np.arange(len(self.nodes))
if len(self.symmetries) > 0:
symmetry_inds = np.array(
Expand Down Expand Up @@ -217,11 +250,14 @@ def add_node(self, node: Node | str):
Args:
node: A `Node` object or a string name to create a new node.
"""
node_name = node.name if type(node) == Node else node
if node_name in self._node_name_map:
raise ValueError(f"Node '{node_name}' already exists in the skeleton.")
if type(node) == str:
node = Node(node)
if node not in self.nodes:
self.nodes.append(node)
self._update_node_map(None, self.nodes)
self._update_node_map(None, self.nodes)

def add_edge(self, src: Edge | Node | str = None, dst: Node | str = None):
"""Add an `Edge` to the skeleton.
Expand Down Expand Up @@ -268,9 +304,10 @@ def add_symmetry(
node2: The second `Node` or name of the second node.
"""
if type(node1) == Symmetry:
if node1 not in self.symmetries:
self.symmetries.append(node1)
for node in node1.nodes:
symmetry = node1
if symmetry not in self.symmetries:
self.symmetries.append(symmetry)
for node in symmetry.nodes:
if node not in self.nodes:
self.add_node(node)
return
Expand Down
13 changes: 8 additions & 5 deletions tests/io/test_slp.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ def test_read_skeleton(centered_pair):
assert len(skeleton.nodes) == 24
assert len(skeleton.edges) == 23
assert len(skeleton.symmetries) == 20
assert Node("wingR") in skeleton.symmetries[0].nodes
assert Node("wingL") in skeleton.symmetries[0].nodes
assert "wingR" in skeleton.symmetry_names[0]
assert "wingL" in skeleton.symmetry_names[0]


def test_read_videos_pkg(slp_minimal_pkg):
Expand Down Expand Up @@ -157,7 +157,10 @@ def test_write_metadata(centered_pair, tmp_path):
assert saved_skeletons[0].name == labels.skeletons[0].name
assert saved_skeletons[0].node_names == labels.skeletons[0].node_names
assert saved_skeletons[0].edge_inds == labels.skeletons[0].edge_inds
assert saved_skeletons[0].flipped_node_inds == labels.skeletons[0].flipped_node_inds
assert (
saved_skeletons[0].get_flipped_node_inds()
== labels.skeletons[0].get_flipped_node_inds()
)


def test_write_lfs(centered_pair, slp_real_data, tmp_path):
Expand Down Expand Up @@ -224,8 +227,8 @@ def test_load_multi_skeleton(tmpdir):
assert loaded_skels[1].node_names == ["n3", "n4"]
assert loaded_skels[0].edge_inds == [(0, 1)]
assert loaded_skels[1].edge_inds == [(0, 1)]
assert loaded_skels[0].flipped_node_inds == [1, 0]
assert loaded_skels[1].flipped_node_inds == [1, 0]
assert loaded_skels[0].get_flipped_node_inds() == [1, 0]
assert loaded_skels[1].get_flipped_node_inds() == [1, 0]


def test_slp_imgvideo(tmpdir, slp_imgvideo):
Expand Down
5 changes: 3 additions & 2 deletions tests/model/test_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@

def test_labels():
"""Test methods in the `Labels` data structure."""
skel = Skeleton(["A", "B"])
labels = Labels(
[
LabeledFrame(
video=Video(filename="test"),
frame_idx=0,
instances=[
Instance([[0, 1], [2, 3]], skeleton=Skeleton(["A", "B"])),
PredictedInstance([[4, 5], [6, 7]], skeleton=Skeleton(["A", "B"])),
Instance([[0, 1], [2, 3]], skeleton=skel),
PredictedInstance([[4, 5], [6, 7]], skeleton=skel),
],
)
]
Expand Down
48 changes: 19 additions & 29 deletions tests/model/test_skeleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,18 +75,16 @@ def test_skeleton_node_map():
assert skel.index("B") == 0


def test_flipped_node_inds():
def test_get_flipped_node_inds():
skel = Skeleton(["A", "BL", "BR", "C", "DL", "DR"])
assert skel.flipped_node_inds == [0, 1, 2, 3, 4, 5]
assert skel.get_flipped_node_inds() == [0, 1, 2, 3, 4, 5]

skel.symmetries = [
Symmetry([Node("BL"), Node("BR")]),
Symmetry([Node("DL"), Node("DR")]),
]
assert skel.flipped_node_inds == [0, 2, 1, 3, 5, 4]
skel.add_symmetry("BL", "BR")
skel.add_symmetry("DL", "DR")
assert skel.get_flipped_node_inds() == [0, 2, 1, 3, 5, 4]

assert skel.symmetries[0][0] in (Node("BL"), Node("BR"))
assert skel.symmetries[0][1] in (Node("BL"), Node("BR"))
assert skel.symmetries[0][0].name in ("BL", "BR")
assert skel.symmetries[0][1].name in ("BL", "BR")
syms = list(skel.symmetries[0])
assert syms[0] != syms[1]

Expand All @@ -113,8 +111,9 @@ def test_add_node():
skel.add_node("C")
assert skel.node_names == ["A", "B", "C"]

skel.add_node("B")
assert skel.node_names == ["A", "B", "C"]
with pytest.raises(ValueError):
skel.add_node("B")
assert skel.node_names == ["A", "B", "C"]
Comment on lines +114 to +116
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Remove unreachable assertion in error handling test.

The assertion after raise ValueError is unreachable. The test should be restructured to properly verify the error case.

Apply this change:

     with pytest.raises(ValueError):
         skel.add_node("B")
-        assert skel.node_names == ["A", "B", "C"]
+    # Verify state remains unchanged after failed operation
+    assert skel.node_names == ["A", "B", "C"]
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
with pytest.raises(ValueError):
skel.add_node("B")
assert skel.node_names == ["A", "B", "C"]
with pytest.raises(ValueError):
skel.add_node("B")
# Verify state remains unchanged after failed operation
assert skel.node_names == ["A", "B", "C"]



def test_add_edge():
Expand All @@ -135,31 +134,22 @@ def test_add_edge():
skel.add_edge("D", "A")
assert skel.edge_inds == [(0, 1), (1, 0), (0, 2), (3, 0)]

skel = Skeleton(["A", "B"])
skel.add_edge(Edge(Node("A"), Node("B")))
assert skel.edge_inds == [(0, 1)]

skel.add_edge(Edge(Node("C"), Node("D")))
assert skel.edge_inds == [(0, 1), (2, 3)]


def test_add_symmetry():
skel = Skeleton(["A", "B"])
skel.add_symmetry("A", "B")
assert skel.symmetries == [Symmetry([Node("A"), Node("B")])]
assert skel.symmetry_inds == [(0, 1)]
assert skel.symmetry_names == [("A", "B")]

# Don't duplicate reversed symmetries
skel.add_symmetry("B", "A")
assert skel.symmetries == [Symmetry([Node("A"), Node("B")])]
assert skel.symmetry_inds == [(0, 1)]
assert skel.symmetry_names == [("A", "B")]

# Add new symmetry with new node objects
skel.add_symmetry(Symmetry([Node("C"), Node("D")]))
assert skel.symmetries == [
Symmetry([Node("A"), Node("B")]),
Symmetry([Node("C"), Node("D")]),
]
assert skel.symmetry_inds == [(0, 1), (2, 3)]

# Add new symmetry with node names
skel.add_symmetry("E", "F")
assert skel.symmetries == [
Symmetry([Node("A"), Node("B")]),
Symmetry([Node("C"), Node("D")]),
Symmetry([Node("E"), Node("F")]),
]
assert skel.symmetry_inds == [(0, 1), (2, 3), (4, 5)]
Comment on lines +141 to +155
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Consider adding mutability verification.

The test thoroughly covers symmetry addition scenarios, but given the PR's focus on node mutability, consider adding a test case that verifies nodes remain mutable after being added to symmetries.

Add this test case:

     # Add new symmetry with node names
     skel.add_symmetry("E", "F")
     assert skel.symmetry_inds == [(0, 1), (2, 3), (4, 5)]
+
+    # Verify nodes remain mutable after symmetry addition
+    node = skel["E"]
+    original_hash = hash(node)
+    node.x = 42
+    assert node.x == 42
+    assert hash(node) == original_hash  # Hash should be stable
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
assert skel.symmetry_inds == [(0, 1)]
assert skel.symmetry_names == [("A", "B")]
# Don't duplicate reversed symmetries
skel.add_symmetry("B", "A")
assert skel.symmetries == [Symmetry([Node("A"), Node("B")])]
assert skel.symmetry_inds == [(0, 1)]
assert skel.symmetry_names == [("A", "B")]
# Add new symmetry with new node objects
skel.add_symmetry(Symmetry([Node("C"), Node("D")]))
assert skel.symmetries == [
Symmetry([Node("A"), Node("B")]),
Symmetry([Node("C"), Node("D")]),
]
assert skel.symmetry_inds == [(0, 1), (2, 3)]
# Add new symmetry with node names
skel.add_symmetry("E", "F")
assert skel.symmetries == [
Symmetry([Node("A"), Node("B")]),
Symmetry([Node("C"), Node("D")]),
Symmetry([Node("E"), Node("F")]),
]
assert skel.symmetry_inds == [(0, 1), (2, 3), (4, 5)]
assert skel.symmetry_inds == [(0, 1)]
assert skel.symmetry_names == [("A", "B")]
# Don't duplicate reversed symmetries
skel.add_symmetry("B", "A")
assert skel.symmetry_inds == [(0, 1)]
assert skel.symmetry_names == [("A", "B")]
# Add new symmetry with new node objects
skel.add_symmetry(Symmetry([Node("C"), Node("D")]))
assert skel.symmetry_inds == [(0, 1), (2, 3)]
# Add new symmetry with node names
skel.add_symmetry("E", "F")
assert skel.symmetry_inds == [(0, 1), (2, 3), (4, 5)]
# Verify nodes remain mutable after symmetry addition
node = skel["E"]
original_hash = hash(node)
node.x = 42
assert node.x == 42
assert hash(node) == original_hash # Hash should be stable

Loading