Skip to content

Commit

Permalink
Add Labels-level skeleton manipulation with instance updating
Browse files Browse the repository at this point in the history
  • Loading branch information
talmo committed Nov 1, 2024
1 parent e50025d commit 17950c4
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 13 deletions.
121 changes: 120 additions & 1 deletion sleap_io/model/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,19 @@

from __future__ import annotations
from sleap_io import (
Skeleton,
LabeledFrame,
Instance,
PredictedInstance,
Video,
Track,
SuggestionFrame,
)
from sleap_io.model.skeleton import NodeOrIndex
from attrs import define, field
from typing import Iterator, Union, Optional, Any
import numpy as np
from pathlib import Path
from sleap_io.model.skeleton import Skeleton
from copy import deepcopy


Expand Down Expand Up @@ -464,6 +465,124 @@ def instances(self) -> Iterator[Instance]:
"""Return an iterator over all instances within all labeled frames."""
return (instance for lf in self.labeled_frames for instance in lf.instances)

def rename_nodes(
self,
name_map: dict[NodeOrIndex, str] | list[str],
skeleton: Skeleton | None = None,
):
"""Rename nodes in the skeleton.
Args:
name_map: A dictionary mapping old node names to new node names. Keys can be
specified as `Node` objects, integer indices, or string names. Values
must be specified as string names.
If a list of strings is provided of the same length as the current
nodes, the nodes will be renamed to the names in the list in order.
skeleton: `Skeleton` to update. If `None` (the default), assumes there is
only one skeleton in the labels and raises `ValueError` otherwise.
Raises:
ValueError: If the new node names exist in the skeleton, if the old node
names are not found in the skeleton, or if there is more than one
skeleton in the `Labels` but it is not specified.
Notes:
This method is recommended over `Skeleton.rename_nodes` as it will update
all instances in the labels to reflect the new node names.
Example:
>>> labels = Labels(skeletons=[Skeleton(["A", "B", "C"])])
>>> labels.rename_nodes({"A": "X", "B": "Y", "C": "Z"})
>>> labels.skeleton.node_names
["X", "Y", "Z"]
>>> labels.rename_nodes(["a", "b", "c"])
>>> labels.skeleton.node_names
["a", "b", "c"]
"""
if skeleton is None:
if len(self.skeletons) != 1:
raise ValueError(
"Skeleton must be specified when there is more than one skeleton in "
"the labels."
)
skeleton = self.skeleton

skeleton.rename_nodes(name_map)

for inst in self.instances:
if inst.skeleton == skeleton:
inst.update_skeleton()

def remove_nodes(self, nodes: list[NodeOrIndex], skeleton: Skeleton | None = None):
"""Remove nodes from the skeleton.
Args:
nodes: A list of node names, indices, or `Node` objects to remove.
skeleton: `Skeleton` to update. If `None` (the default), assumes there is
only one skeleton in the labels and raises `ValueError` otherwise.
Raises:
ValueError: If the nodes are not found in the skeleton, or if there is more
than one skeleton in the `Labels` but it is not specified.
Notes:
This method should always be used when removing nodes from the skeleton as
it handles updating the lookup caches necessary for indexing nodes by name,
and updating instances to reflect the changes made to the skeleton.
Any edges and symmetries that are connected to the removed nodes will also
be removed.
"""
if skeleton is None:
if len(self.skeletons) != 1:
raise ValueError(
"Skeleton must be specified when there is more than one skeleton "
"in the labels."
)
skeleton = self.skeleton

skeleton.remove_nodes(nodes)

for inst in self.instances:
if inst.skeleton == skeleton:
inst.update_skeleton()

def reorder_nodes(
self, new_order: list[NodeOrIndex], skeleton: Skeleton | None = None
):
"""Reorder nodes in the skeleton.
Args:
new_order: A list of node names, indices, or `Node` objects specifying the
new order of the nodes.
skeleton: `Skeleton` to update. If `None` (the default), assumes there is
only one skeleton in the labels and raises `ValueError` otherwise.
Raises:
ValueError: If the new order of nodes is not the same length as the current
nodes, or if there is more than one skeleton in the `Labels` but it is
not specified.
Notes:
This method handles updating the lookup caches necessary for indexing nodes
by name, as well as updating instances to reflect the changes made to the
skeleton.
"""
if skeleton is None:
if len(self.skeletons) != 1:
raise ValueError(
"Skeleton must be specified when there is more than one skeleton "
"in the labels."
)
skeleton = self.skeleton

skeleton.reorder_nodes(new_order)

for inst in self.instances:
if inst.skeleton == skeleton:
inst.update_skeleton()

def replace_videos(
self,
old_videos: list[Video] | None = None,
Expand Down
22 changes: 11 additions & 11 deletions sleap_io/model/skeleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __getitem__(self, idx) -> Node:
NodeOrIndex: TypeAlias = Node | str | int


@define
@define(eq=False)
class Skeleton:
"""A description of a set of landmark types and connections between them.
Expand Down Expand Up @@ -420,6 +420,10 @@ def rename_nodes(self, name_map: dict[NodeOrIndex, str] | list[str]):
If a list of strings is provided of the same length as the current
nodes, the nodes will be renamed to the names in the list in order.
Raises:
ValueError: If the new node names exist in the skeleton or if the old node
names are not found in the skeleton.
Notes:
This method should always be used when renaming nodes in the skeleton as it
handles updating the lookup caches necessary for indexing nodes by name.
Expand All @@ -436,10 +440,6 @@ def rename_nodes(self, name_map: dict[NodeOrIndex, str] | list[str]):
>>> skel.rename_nodes(["a", "b", "c"])
>>> skel.node_names
["a", "b", "c"]
Raises:
ValueError: If the new node names exist in the skeleton or if the old node
names are not found in the skeleton.
"""
if type(name_map) == list:
if len(name_map) != len(self.nodes):
Expand Down Expand Up @@ -482,8 +482,8 @@ def remove_nodes(self, nodes: list[NodeOrIndex]):
nodes: A list of node names, indices, or `Node` objects to remove.
Notes:
This method should always be used when removing nodes from the skeleton as
it handles updating the lookup caches necessary for indexing nodes by name.
This method handles updating the lookup caches necessary for indexing nodes
by name.
Any edges and symmetries that are connected to the removed nodes will also
be removed.
Expand Down Expand Up @@ -531,8 +531,8 @@ def remove_node(self, node: NodeOrIndex):
or `Node` object.
Notes:
This method should always be used when removing nodes from the skeleton as
it handles updating the lookup caches necessary for indexing nodes by name.
This method handles updating the lookup caches necessary for indexing nodes
by name.
Any edges and symmetries that are connected to the removed node will also be
removed.
Expand Down Expand Up @@ -561,8 +561,8 @@ def reorder_nodes(self, new_order: list[NodeOrIndex]):
nodes.
Notes:
This method should always be used when reordering nodes in the skeleton as
it handles updating the lookup caches necessary for indexing nodes by name.
This method handles updating the lookup caches necessary for indexing nodes
by name.
Warning:
After reordering, instances using this skeleton do not need to be updated as
Expand Down
53 changes: 52 additions & 1 deletion tests/model/test_labels.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Test methods and functions in the sleap_io.model.labels file."""

from numpy.testing import assert_equal
from numpy.testing import assert_equal, assert_allclose
import pytest
from sleap_io import (
Video,
Expand Down Expand Up @@ -665,3 +665,54 @@ def test_labels_instances():
)
)
assert len(list(labels.instances)) == 3


def test_labels_rename_nodes(slp_real_data):
labels = load_slp(slp_real_data)
assert labels.skeleton.node_names == ["head", "abdomen"]

labels.rename_nodes({"head": "front", "abdomen": "back"})
assert labels.skeleton.node_names == ["front", "back"]

labels.skeletons.append(Skeleton(["A", "B"]))
with pytest.raises(ValueError):
labels.rename_nodes({"A": "a", "B": "b"})
labels.rename_nodes({"A": "a", "B": "b"}, skeleton=labels.skeletons[1])
assert labels.skeletons[1].node_names == ["a", "b"]


def test_labels_remove_nodes(slp_real_data):
labels = load_slp(slp_real_data)
assert labels.skeleton.node_names == ["head", "abdomen"]
assert_allclose(
labels[0][0].numpy(), [[91.886988, 204.018843], [151.536969, 159.825034]]
)

labels.remove_nodes(["head"])
assert labels.skeleton.node_names == ["abdomen"]
assert_allclose(labels[0][0].numpy(), [[151.536969, 159.825034]])

for inst in labels.instances:
assert inst.numpy().shape == (1, 2)

labels.skeletons.append(Skeleton())
with pytest.raises(ValueError):
labels.remove_nodes(["head"])


def test_labels_reorder_nodes(slp_real_data):
labels = load_slp(slp_real_data)
assert labels.skeleton.node_names == ["head", "abdomen"]
assert_allclose(
labels[0][0].numpy(), [[91.886988, 204.018843], [151.536969, 159.825034]]
)

labels.reorder_nodes(["abdomen", "head"])
assert labels.skeleton.node_names == ["abdomen", "head"]
assert_allclose(
labels[0][0].numpy(), [[151.536969, 159.825034], [91.886988, 204.018843]]
)

labels.skeletons.append(Skeleton())
with pytest.raises(ValueError):
labels.reorder_nodes(["head", "abdomen"])

0 comments on commit 17950c4

Please sign in to comment.