Skip to content

Commit

Permalink
Add instance-level updating
Browse files Browse the repository at this point in the history
  • Loading branch information
talmo committed Nov 1, 2024
1 parent d43db72 commit e50025d
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 6 deletions.
33 changes: 33 additions & 0 deletions sleap_io/model/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,39 @@ def numpy(self) -> np.ndarray:
pts[self.skeleton.index(node)] = point.numpy()
return pts

def update_skeleton(self):
"""Update the points dictionary to match the skeleton.
Points associated with nodes that are no longer in the skeleton will be removed.
Additionally, the keys of the points dictionary will be ordered to match the
order of the nodes in the skeleton.
Notes:
This method is useful when the skeleton has been updated (e.g., nodes
removed or reordered).
However, it is recommended to use `Labels`-level methods (e.g.,
`Labels.remove_nodes()`) when manipulating the skeleton as these will
automatically call this method on every instance.
It is NOT necessary to call this method after renaming nodes.
"""
# Create a new dictionary to hold the updated points
new_points = {}

# Iterate over the nodes in the skeleton
for node in self.skeleton.nodes:
# Get the point associated with the node
point = self.points.get(node, None)

# If the point is not None, add it to the new dictionary
if point is not None:
new_points[node] = point

# Update the points dictionary
self.points = new_points


@define
class PredictedInstance(Instance):
Expand Down
25 changes: 19 additions & 6 deletions sleap_io/model/skeleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,10 +542,10 @@ def remove_node(self, node: NodeOrIndex):
changes.
It is recommended to use the `Labels.remove_nodes()` method which will
update all contained to reflect the changes made to the skeleton.
update all contained instances to reflect the changes made to the skeleton.
To manually update instances after this method is called, call
`instance.update_skeleton()` on each instance that uses this skeleton.
`Instance.update_skeleton()` on each instance that uses this skeleton.
"""
self.remove_nodes([node])

Expand All @@ -556,16 +556,29 @@ def reorder_nodes(self, new_order: list[NodeOrIndex]):
new_order: A list of node names, indices, or `Node` objects specifying the
new order of the nodes.
Raises:
ValueError: If the new order of nodes is not the same length as the current
nodes.
Notes:
This method should always be used when reordering nodes in the skeleton as
it handles updating the lookup caches necessary for indexing nodes by name.
After reordering, instances using this skeleton do NOT need to be updated as
Warning:
After reordering, instances using this skeleton do not need to be updated as
the nodes are stored by reference in the skeleton.
Raises:
ValueError: If the new order of nodes is not the same length as the current
nodes.
However, the order that points are stored in the instances will not be
updated to match the new order of the nodes in the skeleton. This should not
matter unless the ordering of the keys in the `Instance.points` dictionary
is used instead of relying on the skeleton node order.
To make sure these are aligned, it is recommended to use the
`Labels.reorder_nodes()` method which will update all contained instances to
reflect the changes made to the skeleton.
To manually update instances after this method is called, call
`Instance.update_skeleton()` on each instance that uses this skeleton.
"""
if len(new_order) != len(self.nodes):
raise ValueError(
Expand Down
33 changes: 33 additions & 0 deletions tests/model/test_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,36 @@ def test_predicted_instance():
str(inst) == "PredictedInstance(points=[[0.0, 1.0], [2.0, 3.0]], track=None, "
"score=0.60, tracking_score=None)"
)


def test_instance_update_skeleton():
skel = Skeleton(["A", "B", "C"])
inst = Instance.from_numpy([[0, 0], [1, 1], [2, 2]], skeleton=skel)

# No need to update on rename
skel.rename_nodes({"A": "X", "B": "Y", "C": "Z"})
assert inst["X"].x == 0
assert inst["Y"].x == 1
assert inst["Z"].x == 2
assert_equal(inst.numpy(), [[0, 0], [1, 1], [2, 2]])

# Remove a node from the skeleton
Y = skel["Y"]
skel.remove_node("Y")
assert Y not in skel

with pytest.raises(KeyError):
inst.numpy() # .numpy() breaks without update
assert Y in inst.points # and the points dict still has the old key
inst.update_skeleton()
assert Y not in inst.points # after update, the old key is gone
assert_equal(inst.numpy(), [[0, 0], [2, 2]])

# Reorder nodes
skel.reorder_nodes(["Z", "X"])
assert_equal(inst.numpy(), [[2, 2], [0, 0]]) # .numpy() works without update
assert (
list(inst.points.keys()) != skel.nodes
) # but the points dict still has the old order
inst.update_skeleton()
assert list(inst.points.keys()) == skel.nodes # after update, the order is correct

0 comments on commit e50025d

Please sign in to comment.