Skip to content

Commit

Permalink
Add Labels.replace_skeleton() and Instance.replace_skeleton()
Browse files Browse the repository at this point in the history
  • Loading branch information
talmo committed Nov 1, 2024
1 parent 17950c4 commit 897752b
Show file tree
Hide file tree
Showing 4 changed files with 180 additions and 2 deletions.
56 changes: 54 additions & 2 deletions sleap_io/model/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from attrs import define, validators, field, cmp_using
from typing import ClassVar, Optional, Union, cast
from sleap_io import Skeleton, Node
from sleap_io.model.skeleton import NodeOrIndex
import numpy as np
import math

Expand Down Expand Up @@ -323,8 +324,6 @@ def update_skeleton(self):
However, it is recommended to use `Labels`-level methods (e.g.,
`Labels.remove_nodes()`) when manipulating the skeleton as these will
automatically call this method on every instance.
It is NOT necessary to call this method after renaming nodes.
"""
# Create a new dictionary to hold the updated points
new_points = {}
Expand All @@ -341,6 +340,59 @@ def update_skeleton(self):
# Update the points dictionary
self.points = new_points

def replace_skeleton(
self,
new_skeleton: Skeleton,
node_map: dict[NodeOrIndex, NodeOrIndex] | None = None,
rev_node_map: dict[NodeOrIndex, NodeOrIndex] | None = None,
):
"""Replace the skeleton associated with the instance.
The points dictionary will be updated to match the new skeleton.
Args:
new_skeleton: The new `Skeleton` to associate with the instance.
node_map: Dictionary mapping nodes in the old skeleton to nodes in the new
skeleton. Keys and values can be specified as `Node` objects, integer
indices, or string names. If not provided, only nodes with identical
names will be mapped. Points associated with unmapped nodes will be
removed.
rev_node_map: Dictionary mapping nodes in the new skeleton to nodes in the
old skeleton. This is used internally when calling from
`Labels.replace_skeleton()` as it is more efficient to compute this
mapping once and pass it to all instances. No validation is done on this
mapping, so nodes are expected to be `Node` objects.
"""
if rev_node_map is None:
if node_map is None:
node_map = {}
for old_node in self.skeleton.nodes:
for new_node in new_skeleton.nodes:
if old_node.name == new_node.name:
node_map[old_node] = new_node
break
else:
node_map = {
self.skeleton.require_node(
old, add_missing=False
): new_skeleton.require_node(new, add_missing=False)
for old, new in node_map.items()
}

# Make new -> old mapping for nodes
rev_node_map = {new: old for old, new in node_map.items()}

# Build new points list with mapped nodes
new_points = {}
for new_node in new_skeleton.nodes:
old_node = rev_node_map.get(new_node, None)
if old_node is not None and old_node in self.points:
new_points[new_node] = self.points[old_node]

# Update the skeleton and points
self.skeleton = new_skeleton
self.points = new_points


@define
class PredictedInstance(Instance):
Expand Down
62 changes: 62 additions & 0 deletions sleap_io/model/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,68 @@ def reorder_nodes(
if inst.skeleton == skeleton:
inst.update_skeleton()

def replace_skeleton(
self,
new_skeleton: Skeleton,
old_skeleton: Skeleton | None = None,
node_map: dict[NodeOrIndex, NodeOrIndex] | None = None,
):
"""Replace the skeleton in the labels.
Args:
new_skeleton: The new `Skeleton` to replace the old skeleton with.
old_skeleton: The old `Skeleton` to replace. If `None` (the default),
assumes there is only one skeleton in the labels and raises `ValueError`
otherwise.
node_map: Dictionary mapping nodes in the old skeleton to nodes in the new
skeleton. Keys and values can be specified as `Node` objects, integer
indices, or string names. If not provided, only nodes with identical
names will be mapped. Points associated with unmapped nodes will be
removed.
Raises:
ValueError: If there is more than one skeleton in the `Labels` but it is not
specified.
Warning:
This method will replace the skeleton in all instances in the labels that
have the old skeleton. **All point data associated with nodes not in the
`node_map` will be lost.**
"""
if old_skeleton is None:
if len(self.skeletons) != 1:
raise ValueError(

Check warning on line 616 in sleap_io/model/labels.py

View check run for this annotation

Codecov / codecov/patch

sleap_io/model/labels.py#L616

Added line #L616 was not covered by tests
"Old skeleton must be specified when there is more than one "
"skeleton in the labels."
)
old_skeleton = self.skeleton

if node_map is None:
node_map = {}
for old_node in old_skeleton.nodes:
for new_node in new_skeleton.nodes:
if old_node.name == new_node.name:
node_map[old_node] = new_node
break
else:
node_map = {
old_skeleton.require_node(
old, add_missing=False
): new_skeleton.require_node(new, add_missing=False)
for old, new in node_map.items()
}

# Make new -> old mapping for nodes for efficiency.
rev_node_map = {new: old for old, new in node_map.items()}

# Replace the skeleton in the instances.
for inst in self.instances:
if inst.skeleton == old_skeleton:
inst.replace_skeleton(new_skeleton, rev_node_map=rev_node_map)

# Replace the skeleton in the labels.
self.skeletons[self.skeletons.index(old_skeleton)] = new_skeleton

def replace_videos(
self,
old_videos: list[Video] | None = None,
Expand Down
33 changes: 33 additions & 0 deletions tests/model/test_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,36 @@ def test_instance_update_skeleton():
) # but the points dict still has the old order
inst.update_skeleton()
assert list(inst.points.keys()) == skel.nodes # after update, the order is correct


def test_instance_replace_skeleton():
# Full replacement
old_skel = Skeleton(["A", "B", "C"])
inst = Instance.from_numpy([[0, 0], [1, 1], [2, 2]], skeleton=old_skel)
new_skel = Skeleton(["X", "Y", "Z"])
inst.replace_skeleton(new_skel, node_map={"A": "X", "B": "Y", "C": "Z"})
assert inst.skeleton == new_skel
assert_equal(inst.numpy(), [[0, 0], [1, 1], [2, 2]])
assert list(inst.points.keys()) == new_skel.nodes

# Partial replacement
old_skel = Skeleton(["A", "B", "C"])
inst = Instance.from_numpy([[0, 0], [1, 1], [2, 2]], skeleton=old_skel)
new_skel = Skeleton(["X", "C", "Y"])
inst.replace_skeleton(new_skel)
assert inst.skeleton == new_skel
assert_equal(inst.numpy(), [[np.nan, np.nan], [2, 2], [np.nan, np.nan]])
assert new_skel["C"] in inst.points
assert old_skel["A"] not in inst.points
assert old_skel["C"] not in inst.points

# Fast path with reverse node map
old_skel = Skeleton(["A", "B", "C"])
inst = Instance.from_numpy([[0, 0], [1, 1], [2, 2]], skeleton=old_skel)
new_skel = Skeleton(["X", "Y", "Z"])
rev_node_map = {
new_node: old_node for new_node, old_node in zip(new_skel.nodes, old_skel.nodes)
}
inst.replace_skeleton(new_skel, rev_node_map=rev_node_map)
assert inst.skeleton == new_skel
assert_equal(inst.numpy(), [[0, 0], [1, 1], [2, 2]])
31 changes: 31 additions & 0 deletions tests/model/test_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,3 +716,34 @@ def test_labels_reorder_nodes(slp_real_data):
labels.skeletons.append(Skeleton())
with pytest.raises(ValueError):
labels.reorder_nodes(["head", "abdomen"])


def test_labels_replace_skeleton(slp_real_data):
labels = load_slp(slp_real_data)
assert labels.skeleton.node_names == ["head", "abdomen"]
inst = labels[0][0]
assert_allclose(inst.numpy(), [[91.886988, 204.018843], [151.536969, 159.825034]])

# Replace with full mapping
new_skel = Skeleton(["ABDOMEN", "HEAD"])
labels.replace_skeleton(new_skel, node_map={"abdomen": "ABDOMEN", "head": "HEAD"})
assert labels.skeleton == new_skel
inst = labels[0][0]
assert inst.skeleton == new_skel
assert_allclose(inst.numpy(), [[151.536969, 159.825034], [91.886988, 204.018843]])

# Replace with partial (inferred) mapping
new_skel = Skeleton(["x", "ABDOMEN"])
labels.replace_skeleton(new_skel)
assert labels.skeleton == new_skel
inst = labels[0][0]
assert inst.skeleton == new_skel
assert_allclose(inst.numpy(), [[np.nan, np.nan], [151.536969, 159.825034]])

# Replace with no mapping
new_skel = Skeleton(["front", "back"])
labels.replace_skeleton(new_skel)
assert labels.skeleton == new_skel
inst = labels[0][0]
assert inst.skeleton == new_skel
assert_allclose(inst.numpy(), [[np.nan, np.nan], [np.nan, np.nan]])

0 comments on commit 897752b

Please sign in to comment.