diff --git a/sleap_io/model/instance.py b/sleap_io/model/instance.py index 25dcd0aa..f019bd20 100644 --- a/sleap_io/model/instance.py +++ b/sleap_io/model/instance.py @@ -308,6 +308,39 @@ def numpy(self) -> np.ndarray: pts[self.skeleton.index(node)] = point.numpy() return pts + def update_skeleton(self): + """Update the points dictionary to match the skeleton. + + Points associated with nodes that are no longer in the skeleton will be removed. + + Additionally, the keys of the points dictionary will be ordered to match the + order of the nodes in the skeleton. + + Notes: + This method is useful when the skeleton has been updated (e.g., nodes + removed or reordered). + + However, it is recommended to use `Labels`-level methods (e.g., + `Labels.remove_nodes()`) when manipulating the skeleton as these will + automatically call this method on every instance. + + It is NOT necessary to call this method after renaming nodes. + """ + # Create a new dictionary to hold the updated points + new_points = {} + + # Iterate over the nodes in the skeleton + for node in self.skeleton.nodes: + # Get the point associated with the node + point = self.points.get(node, None) + + # If the point is not None, add it to the new dictionary + if point is not None: + new_points[node] = point + + # Update the points dictionary + self.points = new_points + @define class PredictedInstance(Instance): diff --git a/sleap_io/model/skeleton.py b/sleap_io/model/skeleton.py index 56424c43..d17d7844 100644 --- a/sleap_io/model/skeleton.py +++ b/sleap_io/model/skeleton.py @@ -542,10 +542,10 @@ def remove_node(self, node: NodeOrIndex): changes. It is recommended to use the `Labels.remove_nodes()` method which will - update all contained to reflect the changes made to the skeleton. + update all contained instances to reflect the changes made to the skeleton. To manually update instances after this method is called, call - `instance.update_skeleton()` on each instance that uses this skeleton. + `Instance.update_skeleton()` on each instance that uses this skeleton. """ self.remove_nodes([node]) @@ -556,16 +556,29 @@ def reorder_nodes(self, new_order: list[NodeOrIndex]): new_order: A list of node names, indices, or `Node` objects specifying the new order of the nodes. + Raises: + ValueError: If the new order of nodes is not the same length as the current + nodes. + Notes: This method should always be used when reordering nodes in the skeleton as it handles updating the lookup caches necessary for indexing nodes by name. - After reordering, instances using this skeleton do NOT need to be updated as + Warning: + After reordering, instances using this skeleton do not need to be updated as the nodes are stored by reference in the skeleton. - Raises: - ValueError: If the new order of nodes is not the same length as the current - nodes. + However, the order that points are stored in the instances will not be + updated to match the new order of the nodes in the skeleton. This should not + matter unless the ordering of the keys in the `Instance.points` dictionary + is used instead of relying on the skeleton node order. + + To make sure these are aligned, it is recommended to use the + `Labels.reorder_nodes()` method which will update all contained instances to + reflect the changes made to the skeleton. + + To manually update instances after this method is called, call + `Instance.update_skeleton()` on each instance that uses this skeleton. """ if len(new_order) != len(self.nodes): raise ValueError( diff --git a/tests/model/test_instance.py b/tests/model/test_instance.py index e3181ec3..4fe3f5f8 100644 --- a/tests/model/test_instance.py +++ b/tests/model/test_instance.py @@ -164,3 +164,36 @@ def test_predicted_instance(): str(inst) == "PredictedInstance(points=[[0.0, 1.0], [2.0, 3.0]], track=None, " "score=0.60, tracking_score=None)" ) + + +def test_instance_update_skeleton(): + skel = Skeleton(["A", "B", "C"]) + inst = Instance.from_numpy([[0, 0], [1, 1], [2, 2]], skeleton=skel) + + # No need to update on rename + skel.rename_nodes({"A": "X", "B": "Y", "C": "Z"}) + assert inst["X"].x == 0 + assert inst["Y"].x == 1 + assert inst["Z"].x == 2 + assert_equal(inst.numpy(), [[0, 0], [1, 1], [2, 2]]) + + # Remove a node from the skeleton + Y = skel["Y"] + skel.remove_node("Y") + assert Y not in skel + + with pytest.raises(KeyError): + inst.numpy() # .numpy() breaks without update + assert Y in inst.points # and the points dict still has the old key + inst.update_skeleton() + assert Y not in inst.points # after update, the old key is gone + assert_equal(inst.numpy(), [[0, 0], [2, 2]]) + + # Reorder nodes + skel.reorder_nodes(["Z", "X"]) + assert_equal(inst.numpy(), [[2, 2], [0, 0]]) # .numpy() works without update + assert ( + list(inst.points.keys()) != skel.nodes + ) # but the points dict still has the old order + inst.update_skeleton() + assert list(inst.points.keys()) == skel.nodes # after update, the order is correct