diff --git a/sleap_io/model/labels.py b/sleap_io/model/labels.py index 5d6f7268..cec18294 100644 --- a/sleap_io/model/labels.py +++ b/sleap_io/model/labels.py @@ -13,6 +13,7 @@ from __future__ import annotations from sleap_io import ( + Skeleton, LabeledFrame, Instance, PredictedInstance, @@ -20,11 +21,11 @@ Track, SuggestionFrame, ) +from sleap_io.model.skeleton import NodeOrIndex from attrs import define, field from typing import Iterator, Union, Optional, Any import numpy as np from pathlib import Path -from sleap_io.model.skeleton import Skeleton from copy import deepcopy @@ -464,6 +465,124 @@ def instances(self) -> Iterator[Instance]: """Return an iterator over all instances within all labeled frames.""" return (instance for lf in self.labeled_frames for instance in lf.instances) + def rename_nodes( + self, + name_map: dict[NodeOrIndex, str] | list[str], + skeleton: Skeleton | None = None, + ): + """Rename nodes in the skeleton. + + Args: + name_map: A dictionary mapping old node names to new node names. Keys can be + specified as `Node` objects, integer indices, or string names. Values + must be specified as string names. + + If a list of strings is provided of the same length as the current + nodes, the nodes will be renamed to the names in the list in order. + skeleton: `Skeleton` to update. If `None` (the default), assumes there is + only one skeleton in the labels and raises `ValueError` otherwise. + + Raises: + ValueError: If the new node names exist in the skeleton, if the old node + names are not found in the skeleton, or if there is more than one + skeleton in the `Labels` but it is not specified. + + Notes: + This method is recommended over `Skeleton.rename_nodes` as it will update + all instances in the labels to reflect the new node names. + + Example: + >>> labels = Labels(skeletons=[Skeleton(["A", "B", "C"])]) + >>> labels.rename_nodes({"A": "X", "B": "Y", "C": "Z"}) + >>> labels.skeleton.node_names + ["X", "Y", "Z"] + >>> labels.rename_nodes(["a", "b", "c"]) + >>> labels.skeleton.node_names + ["a", "b", "c"] + """ + if skeleton is None: + if len(self.skeletons) != 1: + raise ValueError( + "Skeleton must be specified when there is more than one skeleton in " + "the labels." + ) + skeleton = self.skeleton + + skeleton.rename_nodes(name_map) + + for inst in self.instances: + if inst.skeleton == skeleton: + inst.update_skeleton() + + def remove_nodes(self, nodes: list[NodeOrIndex], skeleton: Skeleton | None = None): + """Remove nodes from the skeleton. + + Args: + nodes: A list of node names, indices, or `Node` objects to remove. + skeleton: `Skeleton` to update. If `None` (the default), assumes there is + only one skeleton in the labels and raises `ValueError` otherwise. + + Raises: + ValueError: If the nodes are not found in the skeleton, or if there is more + than one skeleton in the `Labels` but it is not specified. + + Notes: + This method should always be used when removing nodes from the skeleton as + it handles updating the lookup caches necessary for indexing nodes by name, + and updating instances to reflect the changes made to the skeleton. + + Any edges and symmetries that are connected to the removed nodes will also + be removed. + """ + if skeleton is None: + if len(self.skeletons) != 1: + raise ValueError( + "Skeleton must be specified when there is more than one skeleton " + "in the labels." + ) + skeleton = self.skeleton + + skeleton.remove_nodes(nodes) + + for inst in self.instances: + if inst.skeleton == skeleton: + inst.update_skeleton() + + def reorder_nodes( + self, new_order: list[NodeOrIndex], skeleton: Skeleton | None = None + ): + """Reorder nodes in the skeleton. + + Args: + new_order: A list of node names, indices, or `Node` objects specifying the + new order of the nodes. + skeleton: `Skeleton` to update. If `None` (the default), assumes there is + only one skeleton in the labels and raises `ValueError` otherwise. + + Raises: + ValueError: If the new order of nodes is not the same length as the current + nodes, or if there is more than one skeleton in the `Labels` but it is + not specified. + + Notes: + This method handles updating the lookup caches necessary for indexing nodes + by name, as well as updating instances to reflect the changes made to the + skeleton. + """ + if skeleton is None: + if len(self.skeletons) != 1: + raise ValueError( + "Skeleton must be specified when there is more than one skeleton " + "in the labels." + ) + skeleton = self.skeleton + + skeleton.reorder_nodes(new_order) + + for inst in self.instances: + if inst.skeleton == skeleton: + inst.update_skeleton() + def replace_videos( self, old_videos: list[Video] | None = None, diff --git a/sleap_io/model/skeleton.py b/sleap_io/model/skeleton.py index d17d7844..7d4176ab 100644 --- a/sleap_io/model/skeleton.py +++ b/sleap_io/model/skeleton.py @@ -74,7 +74,7 @@ def __getitem__(self, idx) -> Node: NodeOrIndex: TypeAlias = Node | str | int -@define +@define(eq=False) class Skeleton: """A description of a set of landmark types and connections between them. @@ -420,6 +420,10 @@ def rename_nodes(self, name_map: dict[NodeOrIndex, str] | list[str]): If a list of strings is provided of the same length as the current nodes, the nodes will be renamed to the names in the list in order. + Raises: + ValueError: If the new node names exist in the skeleton or if the old node + names are not found in the skeleton. + Notes: This method should always be used when renaming nodes in the skeleton as it handles updating the lookup caches necessary for indexing nodes by name. @@ -436,10 +440,6 @@ def rename_nodes(self, name_map: dict[NodeOrIndex, str] | list[str]): >>> skel.rename_nodes(["a", "b", "c"]) >>> skel.node_names ["a", "b", "c"] - - Raises: - ValueError: If the new node names exist in the skeleton or if the old node - names are not found in the skeleton. """ if type(name_map) == list: if len(name_map) != len(self.nodes): @@ -482,8 +482,8 @@ def remove_nodes(self, nodes: list[NodeOrIndex]): nodes: A list of node names, indices, or `Node` objects to remove. Notes: - This method should always be used when removing nodes from the skeleton as - it handles updating the lookup caches necessary for indexing nodes by name. + This method handles updating the lookup caches necessary for indexing nodes + by name. Any edges and symmetries that are connected to the removed nodes will also be removed. @@ -531,8 +531,8 @@ def remove_node(self, node: NodeOrIndex): or `Node` object. Notes: - This method should always be used when removing nodes from the skeleton as - it handles updating the lookup caches necessary for indexing nodes by name. + This method handles updating the lookup caches necessary for indexing nodes + by name. Any edges and symmetries that are connected to the removed node will also be removed. @@ -561,8 +561,8 @@ def reorder_nodes(self, new_order: list[NodeOrIndex]): nodes. Notes: - This method should always be used when reordering nodes in the skeleton as - it handles updating the lookup caches necessary for indexing nodes by name. + This method handles updating the lookup caches necessary for indexing nodes + by name. Warning: After reordering, instances using this skeleton do not need to be updated as diff --git a/tests/model/test_labels.py b/tests/model/test_labels.py index c7622cb6..1b9251de 100644 --- a/tests/model/test_labels.py +++ b/tests/model/test_labels.py @@ -1,6 +1,6 @@ """Test methods and functions in the sleap_io.model.labels file.""" -from numpy.testing import assert_equal +from numpy.testing import assert_equal, assert_allclose import pytest from sleap_io import ( Video, @@ -665,3 +665,54 @@ def test_labels_instances(): ) ) assert len(list(labels.instances)) == 3 + + +def test_labels_rename_nodes(slp_real_data): + labels = load_slp(slp_real_data) + assert labels.skeleton.node_names == ["head", "abdomen"] + + labels.rename_nodes({"head": "front", "abdomen": "back"}) + assert labels.skeleton.node_names == ["front", "back"] + + labels.skeletons.append(Skeleton(["A", "B"])) + with pytest.raises(ValueError): + labels.rename_nodes({"A": "a", "B": "b"}) + labels.rename_nodes({"A": "a", "B": "b"}, skeleton=labels.skeletons[1]) + assert labels.skeletons[1].node_names == ["a", "b"] + + +def test_labels_remove_nodes(slp_real_data): + labels = load_slp(slp_real_data) + assert labels.skeleton.node_names == ["head", "abdomen"] + assert_allclose( + labels[0][0].numpy(), [[91.886988, 204.018843], [151.536969, 159.825034]] + ) + + labels.remove_nodes(["head"]) + assert labels.skeleton.node_names == ["abdomen"] + assert_allclose(labels[0][0].numpy(), [[151.536969, 159.825034]]) + + for inst in labels.instances: + assert inst.numpy().shape == (1, 2) + + labels.skeletons.append(Skeleton()) + with pytest.raises(ValueError): + labels.remove_nodes(["head"]) + + +def test_labels_reorder_nodes(slp_real_data): + labels = load_slp(slp_real_data) + assert labels.skeleton.node_names == ["head", "abdomen"] + assert_allclose( + labels[0][0].numpy(), [[91.886988, 204.018843], [151.536969, 159.825034]] + ) + + labels.reorder_nodes(["abdomen", "head"]) + assert labels.skeleton.node_names == ["abdomen", "head"] + assert_allclose( + labels[0][0].numpy(), [[151.536969, 159.825034], [91.886988, 204.018843]] + ) + + labels.skeletons.append(Skeleton()) + with pytest.raises(ValueError): + labels.reorder_nodes(["head", "abdomen"])