Add skeleton manipulation utilities #136

Merged 13 commits on Nov 1, 2024

@talmo talmo commented Nov 1, 2024

This PR builds on #135 to add a number of methods for Skeleton manipulation, addressing #123 and more.

API changes:

Skeleton

  • __contains__(node: NodeOrIndex): Returns True if a node exists in the skeleton.
  • rebuild_cache(): Method allowing explicit regeneration of the caching attributes from the nodes.
  • Caching attributes are now named _name_to_node_cache and _node_to_ind_cache, better reflecting the mapping directionality.
  • require_node(node: NodeOrIndex, add_missing: bool = True): Returns a Node given a Node, int or str. If add_missing is True, the node is added or created, otherwise an IndexError is raised. This is helpful for flexibly converting between node representations with convenient existence handling.
  • add_nodes(list[Node | str]): Convenience method to add a list of nodes.
  • add_edges(edges: list[Edge | tuple[NodeOrIndex, NodeOrIndex]]): Convenience method to add a list of edges.
  • rename_nodes(name_map: dict[NodeOrIndex, str] | list[str]): Method to rename nodes either by specifying a potentially partial mapping from node(s) to new name(s), or a list of new names. Handles updating both the Node.name attributes and the cache.
  • rename_node(old_name: NodeOrIndex, new_name: str): Shorter syntax for renaming a single node.
  • remove_nodes(nodes: list[NodeOrIndex]): Method for removing nodes from the skeleton and updating caches. Does NOT update corresponding instances.
  • remove_node(node: NodeOrIndex): Shorter syntax for removing a single node.
  • reorder_nodes(new_order: list[NodeOrIndex]): Method for setting the order of the nodes within the skeleton with cache updating. Does NOT update corresponding instances.
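
The cache-backed lookups listed above can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the sleap-io implementation: `MiniNode` and `MiniSkeleton` are hypothetical stand-ins, and only the method names, the cache attribute names, and the `add_missing` behavior are taken from the list above.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)  # identity-based hashing, so nodes can be used as dict keys
class MiniNode:
    name: str


@dataclass
class MiniSkeleton:
    """Hypothetical stand-in for a skeleton with lookup caches."""

    nodes: list = field(default_factory=list)

    def __post_init__(self):
        # Accept plain strings for convenience and wrap them in nodes.
        self.nodes = [n if isinstance(n, MiniNode) else MiniNode(n) for n in self.nodes]
        self.rebuild_cache()

    def rebuild_cache(self):
        # Regenerate both lookups from the node list.
        self._name_to_node_cache = {n.name: n for n in self.nodes}
        self._node_to_ind_cache = {n: i for i, n in enumerate(self.nodes)}

    def __contains__(self, node):
        if isinstance(node, MiniNode):
            return node in self._node_to_ind_cache
        if isinstance(node, str):
            return node in self._name_to_node_cache
        if isinstance(node, int):
            return 0 <= node < len(self.nodes)
        return False

    def require_node(self, node, add_missing=True):
        if isinstance(node, int):
            return self.nodes[node]  # raises IndexError when out of range
        if node in self:
            return node if isinstance(node, MiniNode) else self._name_to_node_cache[node]
        if not add_missing:
            raise IndexError(f"Node '{node}' not found in the skeleton.")
        new_node = node if isinstance(node, MiniNode) else MiniNode(node)
        self.nodes.append(new_node)
        self.rebuild_cache()
        return new_node

    def rename_node(self, old, new_name):
        node = self.require_node(old, add_missing=False)
        if new_name in self._name_to_node_cache:
            raise ValueError(f"Node '{new_name}' already exists in the skeleton.")
        node.name = new_name
        self.rebuild_cache()  # keep the name lookup in sync with Node.name
```

In this sketch, `MiniSkeleton(["A", "B"]).require_node("C")` appends a new node, while passing `add_missing=False` raises `IndexError` instead.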

Instance/PredictedInstance

  • update_skeleton(): Updates the points attribute on the instance to reflect changes in the associated skeleton (removed nodes and reordering). This is called internally after updating the skeleton from the Labels level, but also exposed for more complex data manipulation workflows.
  • replace_skeleton(new_skeleton: Skeleton, node_map: dict[NodeOrIndex, NodeOrIndex] | None = None, rev_node_map: dict[NodeOrIndex, NodeOrIndex] | None = None): Method to replace the skeleton on the instance with optional capability to specify a node mapping so that data stored in the points attribute is retained and associated with the right nodes in the new skeleton. Mapping is specified in node_map from old to new nodes and defaults to mapping between node objects with the same name. rev_node_map maps new nodes to old nodes and is used internally when calling from the Labels level as it bypasses validation.
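
The node remapping described for `replace_skeleton` can be illustrated with plain lists. This is only a sketch: `remap_points` is a hypothetical helper, nodes are identified by name strings only, and unmapped nodes are filled with `None` as an assumption about how missing data might be represented.

```python
def remap_points(points, old_names, new_names, node_map=None):
    """Return point data reordered for a new skeleton.

    `points` holds one entry per node in `old_names`. Entries of the result
    follow `new_names`; nodes without a mapped source are left as None.
    """
    if node_map is None:
        # Default described in the PR: match nodes that share a name.
        node_map = {name: name for name in old_names if name in new_names}

    old_index = {name: i for i, name in enumerate(old_names)}
    new_index = {name: i for i, name in enumerate(new_names)}

    new_points = [None] * len(new_names)
    for old, new in node_map.items():
        new_points[new_index[new]] = points[old_index[old]]
    return new_points


points = [(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)]
old_names = ["head", "thorax", "tail"]
new_names = ["abdomen", "head", "neck"]

# Name-based default: only "head" exists in both skeletons.
by_name = remap_points(points, old_names, new_names)

# Explicit old -> new mapping keeps the thorax data under the new "neck" node.
explicit = remap_points(
    points, old_names, new_names, node_map={"head": "head", "thorax": "neck"}
)
```

A reverse mapping (new to old), like `rev_node_map`, would carry the same information with keys and values swapped.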

Labels

  • instances: Convenience property that returns a generator that loops over all labeled frames and returns all instances. This can be lazily iterated over without having to construct a huge list of all the instances.
  • rename_nodes(name_map: dict[NodeOrIndex, str] | list[str], skeleton: Skeleton | None = None): Method to rename nodes in a specified skeleton within the labels.
  • remove_nodes(nodes: list[NodeOrIndex], skeleton: Skeleton | None = None): Method to remove nodes in a specified skeleton within the labels. This also updates all instances associated with the skeleton, removing point data for the removed nodes.
  • reorder_nodes(new_order: list[NodeOrIndex], skeleton: Skeleton | None = None): Method to reorder nodes in a specified skeleton within the labels. This also updates all instances associated with the skeleton, reordering point data for the nodes.
  • replace_skeleton(new_skeleton: Skeleton, old_skeleton: Skeleton | None = None, node_map: dict[NodeOrIndex, NodeOrIndex] | None = None): Method to replace a skeleton entirely within the labels, updating all instances associated with the old skeleton to use the new skeleton, optionally with node remapping to retain previous point data.
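
To show why the Labels-level methods must also touch instance data, here is a small sketch of propagating a reorder or removal to per-instance points. The functions and the list-of-lists data layout are hypothetical and stand in for the skeleton and its instances; they are not the sleap-io API.

```python
def reorder_nodes(node_names, instances_points, new_order):
    """Reorder node names and every instance's points consistently."""
    if sorted(new_order) != sorted(node_names):
        raise ValueError("new_order must contain exactly the existing nodes.")
    # perm[k] is the old index of the node that ends up at position k.
    perm = [node_names.index(name) for name in new_order]
    reordered = [[pts[i] for i in perm] for pts in instances_points]
    return list(new_order), reordered


def remove_nodes(node_names, instances_points, to_remove):
    """Drop nodes and the matching point entries from every instance."""
    missing = [name for name in to_remove if name not in node_names]
    if missing:
        raise ValueError(f"Nodes not found in the skeleton: {missing}")
    keep = [i for i, name in enumerate(node_names) if name not in to_remove]
    kept_names = [node_names[i] for i in keep]
    kept_points = [[pts[i] for i in keep] for pts in instances_points]
    return kept_names, kept_points


names = ["A", "B", "C"]
instances = [[(0, 0), (1, 1), (2, 2)], [(3, 3), (4, 4), (5, 5)]]

names2, instances2 = reorder_nodes(names, instances, ["C", "A", "B"])
names3, instances3 = remove_nodes(names2, instances2, ["A"])
```

If only the node list changed and the point lists were left alone, each point would silently end up attached to the wrong node.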

Summary by CodeRabbit

Release Notes

  • New Features

    • Introduced new methods for renaming, removing, and reordering nodes in both the Skeleton and Labels classes.
    • Added an instances property in the Labels class for easier access to instance data.
    • Enhanced the Instance class with methods to update and replace skeletons.
  • Bug Fixes

    • Improved error handling for various operations involving nodes, ensuring robustness.
  • Tests

    • Expanded test coverage for the Skeleton, Labels, and Instance classes, validating new functionalities and error conditions.

coderabbitai bot commented Nov 1, 2024

Walkthrough

The changes introduce significant enhancements to the Skeleton class in sleap_io/model/skeleton.py, including a new type alias and methods for renaming nodes, managing node attributes, and maintaining internal caches. The Labels class receives a new property method instances for accessing instance data across labeled frames. Corresponding tests are added to validate the functionalities of renaming, removing, and reordering nodes, as well as ensuring the integrity of instances after skeleton modifications.

Changes

File | Change Summary
sleap_io/model/skeleton.py | Introduced type alias NodeOrIndex, added methods for renaming, adding, removing, and reordering nodes, and updated internal caches and type hints.
tests/model/test_skeleton.py | Added test functions test_rename_nodes(), test_remove_nodes(), and test_reorder_nodes() to ensure correct functionality of node management.
sleap_io/model/labels.py | Added property method instances(self) -> Iterator[Instance] and methods for renaming, removing, and reordering nodes in the Labels class.
tests/model/test_labels.py | Added test functions to validate the new functionalities in the Labels class, including test_labels_instances(), test_labels_rename_nodes(), test_labels_remove_nodes(), and test_labels_reorder_nodes().
sleap_io/model/instance.py | Added methods update_skeleton(self) and `replace_skeleton(self, new_skeleton: Skeleton, node_map: dict[NodeOrIndex, NodeOrIndex]
tests/model/test_instance.py | Added test functions test_instance_update_skeleton() and test_instance_replace_skeleton() to validate behavior after skeleton updates.

Possibly related PRs

  • Make skeleton nodes mutable #135: The changes in this PR modify the Skeleton class and its Node instances, which are directly related to the significant updates made to the Skeleton class in the main PR, particularly regarding node management and error handling.

Poem

In the garden of nodes, where names do play,
A rabbit hops in, to rename and sway.
With a flick of the ear and a twitch of the nose,
New names emerge, as the old ones do close.
Errors are caught, like leaves in a breeze,
In the world of Skeletons, we frolic with ease! 🐇✨


📜 Recent review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 57ee01c and 3f493af.

📒 Files selected for processing (1)
  • sleap_io/model/skeleton.py (6 hunks)
🧰 Additional context used
🪛 Ruff
sleap_io/model/skeleton.py

E721 (Use is and is not for type comparisons, or isinstance() for isinstance checks) at lines 280, 287, 289, 296, 298, 300, 318, 321, 361, 378, 389, 415, 460, 469, 471.
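
For context on the E721 warnings, the snippet below contrasts exact type comparison with `isinstance()`. The `Node` and `TaggedNode` classes are hypothetical stand-ins used only for illustration; the behavior shown is standard Python.

```python
class Node:
    def __init__(self, name):
        self.name = name


class TaggedNode(Node):
    """Hypothetical subclass, used to show how the two checks differ."""


node = TaggedNode("A")

# Exact type comparison ignores inheritance, so a subclass instance fails it.
exact_match = type(node) == Node  # noqa: E721

# isinstance() accepts subclasses, which is usually what a dispatch check wants.
subclass_match = isinstance(node, Node)

# When an exact match really is intended, E721 asks for `is` rather than `==`.
exact_identity = type(node) is TaggedNode

# One consequence of switching: bool is a subclass of int, so isinstance()
# treats True/False as integers while the exact comparison does not.
bool_is_exact_int = type(True) == int  # noqa: E721
bool_is_int_instance = isinstance(True, int)
```

Switching `type(x) == int` to `isinstance(x, int)` therefore slightly widens what is accepted, which may or may not matter for node indexing.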

🔇 Additional comments (14)
sleap_io/model/skeleton.py (14)

74-76: Good use of type alias for NodeOrIndex

Defining NodeOrIndex enhances code readability and type checking.


113-116: Efficient cache update with _nodes_on_setattr

Using on_setattr callback to rebuild caches ensures nodes are always in sync.


118-121: Correctly configured nodes field with on_setattr

The nodes field is properly set up with factory and on_setattr for automatic cache management.


124-126: Initialization of cache dictionaries

Initializing _name_to_node_cache and _node_to_ind_cache prepares the skeleton for efficient node lookups.


132-132: Cache rebuild in post-initialization

Calling self.rebuild_cache() in __attrs_post_init__ ensures caches are up-to-date after initialization.


177-201: Robust rebuild_cache method

The rebuild_cache method effectively updates node caches, supporting both full and partial updates.


208-208: Updated edge_inds property with modern typing

The return type hint list[tuple[int, int]] reflects current typing standards.


221-221: Updated symmetry_inds property with modern typing

Consistently using list[tuple[int, int]] enhances type clarity.


285-304: Consistent node retrieval methods

The __getitem__ and __contains__ methods provide flexible node access by index, name, or node object.

🧰 Tools
🪛 Ruff: E721 (Use is and is not for type comparisons, or isinstance() for isinstance checks) at lines 287, 289, 296, 298, 300.


339-365: Flexible require_node method

The require_node method supports adding missing nodes, improving usability.

🧰 Tools
🪛 Ruff: E721 (Use is and is not for type comparisons, or isinstance() for isinstance checks) at line 361.


366-395: Enhanced add_edge method with versatile input

Accepting various input types for src and dst increases the method's flexibility.

🧰 Tools
🪛 Ruff: E721 (Use is and is not for type comparisons, or isinstance() for isinstance checks) at lines 378, 389.


428-493: Comprehensive rename_nodes method

The rename_nodes method handles multiple input forms and updates caches appropriately.

🧰 Tools
🪛 Ruff: E721 (Use is and is not for type comparisons, or isinstance() for isinstance checks) at lines 460, 469, 471.


494-567: Effective remove_nodes method with cache maintenance

Removing nodes and associated edges and symmetries while updating caches maintains integrity.


568-605: Useful reorder_nodes method

Reordering nodes enhances control over node organization while keeping caches consistent.


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🧹 Outside diff range and nitpick comments (1)
sleap_io/model/skeleton.py (1)

333-363: Consider enhancing type hints and documentation.

The documentation is comprehensive, but here are some suggestions for improvement:

  1. Simplify the type hint using the union operator (Python 3.10+):
-def rename_nodes(self, name_map: dict[str | int | Node, str] | list[str]):
+def rename_nodes(self, name_map: dict[str | int | Node, str] | list[str]):
  2. Add a note about the behavior of node references in the docstring:
Notes:
    This method should always be used when renaming nodes in the skeleton as it
    handles updating the lookup caches necessary for indexing nodes by name.

    After renaming, instances using this skeleton do NOT need to be updated as
    the nodes are stored by reference in the skeleton. This means that any
    existing Edge or Symmetry objects will continue to work correctly with the
    renamed nodes.
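
The by-reference behavior in this note can be demonstrated with a few lines. `Node` and `Edge` here are simplified, hypothetical dataclasses rather than the sleap-io classes; the point is only that objects holding a node reference see a rename, while a name-keyed lookup built earlier does not.

```python
from dataclasses import dataclass


@dataclass(eq=False)  # compare and hash by identity, like a mutable node object
class Node:
    name: str


@dataclass
class Edge:
    source: Node
    destination: Node


a, b = Node("A"), Node("B")
edge = Edge(a, b)

# A name -> node lookup built before the rename (a stand-in for the cache).
name_to_node = {node.name: node for node in (a, b)}

# Rename by mutating the node in place.
a.name = "X"

# The edge stores references, so it reflects the new name automatically.
edge_names = (edge.source.name, edge.destination.name)

# The lookup is keyed by the old string and is now stale until it is updated.
stale_has_old_key = "A" in name_to_node
stale_has_new_key = "X" in name_to_node
```

This is the reason a rename method has to update the name lookup itself even though edges and symmetries need no changes.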
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 611947e and 0fe372c.

📒 Files selected for processing (2)
  • sleap_io/model/skeleton.py (1 hunks)
  • tests/model/test_skeleton.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_io/model/skeleton.py

E721 (Use is and is not for type comparisons, or isinstance() for isinstance checks) at lines 364, 373, 375.

🔇 Additional comments (2)
tests/model/test_skeleton.py (1)

158-171: LGTM! Well-structured test implementation.

The test function follows good testing practices with clear arrange-act-assert patterns and proper error case coverage.

sleap_io/model/skeleton.py (1)

333-386: Verify the impact on existing instances.

The implementation looks correct, but let's verify that node references are preserved.

✅ Verification successful

Based on the search results, I can now generate the final response:

Node references are safely maintained across the codebase

The verification confirms that node name modifications are properly encapsulated:

  • The only direct node.name assignment is within the rename_nodes method itself
  • All other occurrences are read-only assertions in test files or equality comparisons
  • No other methods were found that modify node names
🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check if there are any direct node.name assignments elsewhere in the codebase
# that might bypass this method and break the node map.

# Search for direct name assignments
echo "Searching for direct node.name assignments:"
rg "\.name\s*=" --type py

# Search for other methods that might modify node names
echo -e "\nSearching for methods that might modify node names:"
ast-grep --pattern 'def $FUNC($ARGS) { $$$ $NODE.name $$$}'

Length of output: 1374

🧰 Tools
🪛 Ruff: E721 (Use is and is not for type comparisons, or isinstance() for isinstance checks) at lines 364, 373, 375.

Comment on lines 158 to 171
def test_rename_nodes():
    """Test renaming nodes in the skeleton."""
    skel = Skeleton(["A", "B", "C"])
    skel.rename_nodes({"A": "X", "B": "Y", "C": "Z"})
    assert skel.node_names == ["X", "Y", "Z"]

    skel.rename_nodes(["a", "b", "c"])
    assert skel.node_names == ["a", "b", "c"]

    with pytest.raises(ValueError):
        skel.rename_nodes({"a": "b"})

    with pytest.raises(ValueError):
        skel.rename_nodes({"d": "e"})

🛠️ Refactor suggestion

Consider adding more test cases for comprehensive coverage.

While the current test cases cover the basic functionality, consider adding tests for:

  1. Verifying that edges and symmetries are preserved after renaming
  2. Testing partial renaming with dictionary (only renaming some nodes)
  3. Ensuring node indices remain consistent after renaming

Example test cases to add:

def test_rename_nodes_preserves_structure():
    """Test that renaming preserves edges and symmetries."""
    # Test edges preservation
    skel = Skeleton(["A", "B", "C"], edges=[("A", "B"), ("B", "C")])
    original_edges = skel.edge_inds.copy()
    skel.rename_nodes({"A": "X", "B": "Y", "C": "Z"})
    assert skel.edge_inds == original_edges
    assert skel.edge_names == [("X", "Y"), ("Y", "Z")]

    # Test symmetries preservation
    skel = Skeleton(["L", "R", "C"])
    skel.add_symmetry("L", "R")
    original_symmetries = skel.symmetry_inds.copy()
    skel.rename_nodes(["Left", "Right", "Center"])
    assert skel.symmetry_inds == original_symmetries
    assert skel.symmetry_names == [("Left", "Right")]

def test_partial_rename():
    """Test renaming only some nodes."""
    skel = Skeleton(["A", "B", "C"])
    skel.rename_nodes({"A": "X", "C": "Z"})  # B remains unchanged
    assert skel.node_names == ["X", "B", "Z"]

Comment on lines 372 to 386
        for old_name, new_name in name_map.items():
            if type(old_name) == Node:
                old_name = old_name.name
            if type(new_name) == int:
                new_name = self.nodes[new_name].name

            if old_name not in self._node_name_map:
                raise ValueError(f"Node '{old_name}' not found in the skeleton.")
            if new_name in self._node_name_map:
                raise ValueError(f"Node '{new_name}' already exists in the skeleton.")

            node = self._node_name_map[old_name]
            node.name = new_name
            self._node_name_map[new_name] = node
            del self._node_name_map[old_name]

⚠️ Potential issue

Improve type checking and make node map updates safer.

The node map updates could be made safer and more maintainable:

+        # Validate input types and names first
+        updates = {}
         for old_name, new_name in name_map.items():
-            if type(old_name) == Node:
+            if isinstance(old_name, Node):
                 old_name = old_name.name
-            if type(new_name) == int:
+            if isinstance(new_name, int):
                 new_name = self.nodes[new_name].name
+            if not isinstance(new_name, str):
+                raise ValueError(f"New name must be a string, got {type(new_name)}")

             if old_name not in self._node_name_map:
                 raise ValueError(f"Node '{old_name}' not found in the skeleton.")
             if new_name in self._node_name_map:
                 raise ValueError(f"Node '{new_name}' already exists in the skeleton.")
+            updates[old_name] = (self._node_name_map[old_name], new_name)

+        # Apply updates after validation
+        for old_name, (node, new_name) in updates.items():
-            node = self._node_name_map[old_name]
             node.name = new_name
             self._node_name_map[new_name] = node
             del self._node_name_map[old_name]
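
The validate-then-apply idea in the diff above can be tried in isolation. The sketch below is a simplified assumption-laden version: `rename_validated` is a hypothetical function, nodes are plain dicts with a "name" key, and only string keys are handled. It mirrors the checks in the suggestion and shows that a failing rename leaves the lookup untouched.

```python
def rename_validated(name_to_node, name_map):
    """Rename entries of a name -> node dict, validating everything first."""
    updates = []
    for old_name, new_name in name_map.items():
        if not isinstance(new_name, str):
            raise ValueError(f"New name must be a string, got {type(new_name)}")
        if old_name not in name_to_node:
            raise ValueError(f"Node '{old_name}' not found in the skeleton.")
        if new_name in name_to_node:
            raise ValueError(f"Node '{new_name}' already exists in the skeleton.")
        updates.append((old_name, new_name))

    # Nothing has been mutated so far; apply all renames only after validation.
    for old_name, new_name in updates:
        node = name_to_node.pop(old_name)
        node["name"] = new_name
        name_to_node[new_name] = node


lookup = {"A": {"name": "A"}, "B": {"name": "B"}}

# The second entry is invalid, so the first rename must not be applied either.
try:
    rename_validated(lookup, {"A": "X", "Q": "Y"})
except ValueError:
    pass
keys_after_failure = sorted(lookup)

rename_validated(lookup, {"A": "X"})
keys_after_success = sorted(lookup)
```

Like the suggestion it mirrors, this sketch validates against the pre-rename state only: it rejects swaps such as `{"A": "B", "B": "A"}` and does not detect two old names mapped to the same new name.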
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
        for old_name, new_name in name_map.items():
            if type(old_name) == Node:
                old_name = old_name.name
            if type(new_name) == int:
                new_name = self.nodes[new_name].name
            if old_name not in self._node_name_map:
                raise ValueError(f"Node '{old_name}' not found in the skeleton.")
            if new_name in self._node_name_map:
                raise ValueError(f"Node '{new_name}' already exists in the skeleton.")
            node = self._node_name_map[old_name]
            node.name = new_name
            self._node_name_map[new_name] = node
            del self._node_name_map[old_name]

        # Validate input types and names first
        updates = {}
        for old_name, new_name in name_map.items():
            if isinstance(old_name, Node):
                old_name = old_name.name
            if isinstance(new_name, int):
                new_name = self.nodes[new_name].name
            if not isinstance(new_name, str):
                raise ValueError(f"New name must be a string, got {type(new_name)}")
            if old_name not in self._node_name_map:
                raise ValueError(f"Node '{old_name}' not found in the skeleton.")
            if new_name in self._node_name_map:
                raise ValueError(f"Node '{new_name}' already exists in the skeleton.")
            updates[old_name] = (self._node_name_map[old_name], new_name)
        # Apply updates after validation
        for old_name, (node, new_name) in updates.items():
            node.name = new_name
            self._node_name_map[new_name] = node
            del self._node_name_map[old_name]
🧰 Tools
🪛 Ruff: E721 (Use is and is not for type comparisons, or isinstance() for isinstance checks) at lines 373, 375.

codecov bot commented Nov 1, 2024

Codecov Report

Attention: Patch coverage is 96.04520% with 7 lines in your changes missing coverage. Please review.

Project coverage is 96.33%. Comparing base (611947e) to head (3f493af).
Report is 1 commit behind head on main.

Files with missing lines | Patch % | Lines
sleap_io/model/skeleton.py | 94.23% | 6 Missing ⚠️
sleap_io/model/labels.py | 97.87% | 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #136      +/-   ##
==========================================
+ Coverage   96.13%   96.33%   +0.19%     
==========================================
  Files          17       17              
  Lines        2124     2237     +113     
==========================================
+ Hits         2042     2155     +113     
  Misses         82       82              
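
The percentages in this report can be reproduced from the hit and line counts. The sketch below is a sanity check only: `truncated_pct` is a hypothetical helper, truncation to two decimals is an assumption that happens to match the displayed values (Codecov's actual rounding rule is not stated here), and the patch size of 177 lines is inferred from "7 lines missing" together with the 96.04520% figure rather than reported directly.

```python
import math


def truncated_pct(hits, lines, decimals=2):
    """Coverage percentage truncated (not rounded) to `decimals` places."""
    scale = 10**decimals
    return math.floor(hits / lines * 100 * scale) / scale


base_hits, base_lines = 2042, 2124  # main
head_hits, head_lines = 2155, 2237  # this PR

base_pct = truncated_pct(base_hits, base_lines)
head_pct = truncated_pct(head_hits, head_lines)

# Change in coverage, computed from the untruncated ratios and then truncated.
raw_delta = head_hits / head_lines * 100 - base_hits / base_lines * 100
delta_pct = math.floor(raw_delta * 100) / 100

# Misses follow directly from lines minus hits.
base_misses = base_lines - base_hits
head_misses = head_lines - head_hits

# Inferred patch size: 7 uncovered lines out of 177 changed lines.
patch_pct = round((177 - 7) / 177 * 100, 5)
```

Under these assumptions the computed values agree with the table: base 96.13, head 96.33, change +0.19, and 82 misses on both sides.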

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🧹 Outside diff range and nitpick comments (4)
tests/model/test_skeleton.py (2)

177-193: LGTM, consider adding edge case tests.

The test covers all node identification methods and basic error handling well. Consider adding tests for:

  1. Edge cases (renaming first/last node)
  2. Special characters in node names
  3. Verification that non-target nodes remain unchanged

156-158: Maintain consistent spacing between test functions.

To follow PEP 8 and maintain consistency with the rest of the file, use two blank lines between test functions.

 def test_add_symmetry():
     # ... existing code ...
-

-
 def test_rename_nodes():
     # ... existing code ...
-

-
 def test_rename_node():

Also applies to: 176-177

tests/model/test_labels.py (1)

638-667: LGTM! Consider adding edge cases.

The test implementation looks good and follows the established patterns. It effectively tests the basic functionality of the instances property.

Consider adding these edge cases to make the test more comprehensive:

  1. Empty labeled frames (no instances)
  2. Predicted instances vs. regular instances
  3. Instance filtering (e.g., by track, skeleton)
def test_labels_instances():
    labels = Labels()
    
    # Test empty labels
    assert len(list(labels.instances)) == 0
    
    # Test empty labeled frame
    labels.append(LabeledFrame(video=Video("test.mp4"), frame_idx=0))
    assert len(list(labels.instances)) == 0
    
    # Test regular vs predicted instances
    labels.append(
        LabeledFrame(
            video=Video("test.mp4"),
            frame_idx=1,
            instances=[
                Instance.from_numpy(
                    np.array([[0, 1], [2, 3]]), 
                    skeleton=Skeleton(["A", "B"])
                )
            ],
            predicted_instances=[
                PredictedInstance.from_numpy(
                    np.array([[4, 5], [6, 7]]), 
                    skeleton=Skeleton(["A", "B"])
                )
            ]
        )
    )
    assert len(list(labels.instances)) == 1  # Should only count regular instances
    
    # Test instance filtering
    track = Track("test_track")
    labels.append(
        LabeledFrame(
            video=labels.video,
            frame_idx=2,
            instances=[
                Instance.from_numpy(
                    np.array([[0, 1], [2, 3]]),
                    skeleton=labels.skeleton,
                    track=track
                )
            ]
        )
    )
    # Add assertions for filtering by track, skeleton, etc.
sleap_io/model/labels.py (1)

462-465: LGTM! Well-implemented property using memory-efficient generator expression.

The implementation provides access to all instances across labeled frames using a generator expression. Because the instances are yielded one at a time rather than collected into a list, callers can stream over large collections of labeled frames without holding every instance in memory at once.
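
A minimal sketch of such a lazily evaluated property is shown below. `Frame` and `FrameCollection` are hypothetical stand-ins for the labeled frame and labels containers, and instances are represented by plain strings; only the generator-expression pattern is the point.

```python
from dataclasses import dataclass, field
from typing import Iterator


@dataclass
class Frame:
    frame_idx: int
    instances: list = field(default_factory=list)


@dataclass
class FrameCollection:
    labeled_frames: list = field(default_factory=list)

    @property
    def instances(self) -> Iterator:
        # A generator expression: nothing is iterated until the caller asks.
        return (inst for lf in self.labeled_frames for inst in lf.instances)


collection = FrameCollection(
    [
        Frame(0, ["inst_a", "inst_b"]),
        Frame(1, []),  # frames without instances contribute nothing
        Frame(2, ["inst_c"]),
    ]
)

# Pull only the first instance without visiting the remaining frames.
first = next(collection.instances)

# Each property access returns a fresh generator, so counting starts over.
count = sum(1 for _ in collection.instances)
```

One design consequence is that the property has no `len()` and cannot be indexed; callers who need that can wrap it in `list(...)` at the cost of materializing everything.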

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 0fe372c and 387bab1.

📒 Files selected for processing (4)
  • sleap_io/model/labels.py (2 hunks)
  • sleap_io/model/skeleton.py (1 hunks)
  • tests/model/test_labels.py (1 hunks)
  • tests/model/test_skeleton.py (1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_io/model/skeleton.py

E721 (Use is and is not for type comparisons, or isinstance() for isinstance checks) at lines 364, 373, 375.

🔇 Additional comments (3)
tests/model/test_skeleton.py (1)

158-175: LGTM, but test coverage could be expanded.

The test cases cover the basic functionality and error handling well. However, as mentioned in the previous review, consider adding tests for:

  1. Edge preservation after renaming
  2. Symmetry preservation
  3. Partial renaming scenarios
  4. Node indices consistency
sleap_io/model/skeleton.py (1)

388-396: LGTM!

The implementation is clean and effectively reuses the rename_nodes method, maintaining consistent behavior across both methods.

sleap_io/model/labels.py (1)

24-24: LGTM! Appropriate addition of Iterator type hint import.

The import of Iterator from typing is correctly added to support type hinting of the new instances property.

Comment on lines 333 to 387
    def rename_nodes(self, name_map: dict[str | int | Node, str] | list[str]):
        """Rename nodes in the skeleton.

        Args:
            name_map: A dictionary mapping old node names to new node names. Keys can be
                specified as `Node` objects, integer indices, or string names. Values
                must be specified as string names.

                If a list of strings is provided of the same length as the current
                nodes, the nodes will be renamed to the names in the list in order.

        Notes:
            This method should always be used when renaming nodes in the skeleton as it
            handles updating the lookup caches necessary for indexing nodes by name.

            After renaming, instances using this skeleton do NOT need to be updated as
            the nodes are stored by reference in the skeleton.

        Example:
            >>> skel = Skeleton(["A", "B", "C"], edges=[("A", "B"), ("B", "C")])
            >>> skel.rename_nodes({"A": "X", "B": "Y", "C": "Z"})
            >>> skel.node_names
            ["X", "Y", "Z"]
            >>> skel.rename_nodes(["a", "b", "c"])
            >>> skel.node_names
            ["a", "b", "c"]

        Raises:
            ValueError: If the new node names exist in the skeleton or if the old node
                names are not found in the skeleton.
        """
        if type(name_map) == list:
            if len(name_map) != len(self.nodes):
                raise ValueError(
                    "List of new node names must be the same length as the current "
                    "nodes."
                )
            name_map = {node: name for node, name in zip(self.nodes, name_map)}

        for old_name, new_name in name_map.items():
            if type(old_name) == Node:
                old_name = old_name.name
            if type(old_name) == int:
                old_name = self.nodes[old_name].name

            if old_name not in self._node_name_map:
                raise ValueError(f"Node '{old_name}' not found in the skeleton.")
            if new_name in self._node_name_map:
                raise ValueError(f"Node '{new_name}' already exists in the skeleton.")

            node = self._node_name_map[old_name]
            node.name = new_name
            self._node_name_map[new_name] = node
            del self._node_name_map[old_name]

⚠️ Potential issue

Add validation for empty input and string type.

While the existing implementation handles most edge cases well, consider adding these additional validations:

 def rename_nodes(self, name_map: dict[str | int | Node, str] | list[str]):
+    if not name_map:
+        raise ValueError("name_map cannot be empty")
+
     if isinstance(name_map, list):
+        if not all(isinstance(name, str) for name in name_map):
+            raise ValueError("All new names must be strings")
         if len(name_map) != len(self.nodes):
             raise ValueError(
                 "List of new node names must be the same length as the current "
                 "nodes."
             )
         name_map = {node: name for node, name in zip(self.nodes, name_map)}

     for old_name, new_name in name_map.items():
+        if not isinstance(new_name, str):
+            raise ValueError(f"New name must be a string, got {type(new_name)}")
         if type(old_name) == Node:
             old_name = old_name.name
🧰 Tools
🪛 Ruff

364, 373, 375: Use is and is not for type comparisons, or isinstance() for isinstance checks (E721)

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 5

🧹 Outside diff range and nitpick comments (5)
tests/model/test_instance.py (3)

180-190: Consider enhancing error message verification.

While the test effectively verifies the behavior during node removal, it could be more robust by checking the specific error message.

Consider updating the error check to verify the specific message:

-    with pytest.raises(KeyError):
+    with pytest.raises(KeyError, match="Node 'Y' not found in skeleton"):
         inst.numpy()  # .numpy() breaks without update

195-197: Improve assertion readability.

The multi-line assertion could be more readable using a separate variable.

Consider refactoring to:

-    assert (
-        list(inst.points.keys()) != skel.nodes
-    )  # but the points dict still has the old order
+    points_keys = list(inst.points.keys())
+    assert points_keys != skel.nodes  # but the points dict still has the old order

169-169: Add docstring to test function.

The test function is missing a docstring describing its purpose and the scenarios being tested.

Add a descriptive docstring:

 def test_instance_update_skeleton():
+    """Test Instance behavior when Skeleton is modified.
+    
+    Verifies:
+    1. Node renaming doesn't require instance update
+    2. Node removal requires update_skeleton() call
+    3. Node reordering affects numpy() output and requires update
+    """
tests/model/test_skeleton.py (1)

226-245: Add test case for symmetry preservation during reordering.

While the test cases cover node reordering and its impact on edges, consider adding a test case to verify that symmetry relationships are preserved after reordering nodes.

Example test case:

def test_reorder_nodes_preserves_symmetry():
    skel = Skeleton(["A", "BL", "BR", "C"])
    skel.add_symmetry("BL", "BR")
    original_symmetries = skel.symmetry_inds.copy()
    
    skel.reorder_nodes(["C", "BR", "A", "BL"])
    assert skel.symmetry_inds != original_symmetries  # indices should update
    assert skel.symmetry_names == [("BL", "BR")]  # but relationships should remain
sleap_io/model/instance.py (1)

311-328: Enhance method documentation with type hints and examples.

The documentation is clear but could be more comprehensive. Consider:

  1. Adding return type hint (-> None)
  2. Documenting any exceptions that might be raised
  3. Adding a usage example to illustrate the behavior
-    def update_skeleton(self):
+    def update_skeleton(self) -> None:
         """Update the points dictionary to match the skeleton.
 
         Points associated with nodes that are no longer in the skeleton will be removed.
 
         Additionally, the keys of the points dictionary will be ordered to match the
         order of the nodes in the skeleton.
 
         Notes:
             This method is useful when the skeleton has been updated (e.g., nodes
             removed or reordered).
 
             However, it is recommended to use `Labels`-level methods (e.g.,
             `Labels.remove_nodes()`) when manipulating the skeleton as these will
             automatically call this method on every instance.
 
             It is NOT necessary to call this method after renaming nodes.
+
+        Example:
+            >>> instance = Instance(points=points, skeleton=skeleton)
+            >>> skeleton.remove_nodes(['node_to_remove'])
+            >>> instance.update_skeleton()  # Points for removed node are dropped
+
+        Returns:
+            None
+
+        Raises:
+            AttributeError: If skeleton is not set
         """
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 387bab1 and e50025d.

📒 Files selected for processing (4)
  • sleap_io/model/instance.py (1 hunks)
  • sleap_io/model/skeleton.py (6 hunks)
  • tests/model/test_instance.py (1 hunks)
  • tests/model/test_skeleton.py (3 hunks)
🧰 Additional context used
🪛 Ruff
sleap_io/model/skeleton.py

264, 271, 273, 280, 282, 284, 302, 305, 345, 362, 373, 399, 444, 453, 455: Use is and is not for type comparisons, or isinstance() for isinstance checks (E721)

🔇 Additional comments (11)
tests/model/test_instance.py (2)

169-179: LGTM! Well-structured test setup and rename verification.

The test setup and node renaming verification is clear, comprehensive, and follows testing best practices by checking both direct node access and array representation.


169-199: 🛠️ Refactor suggestion

Consider adding tests for edge cases.

The test provides good coverage of basic operations, but could be enhanced with additional scenarios:

  1. Multiple simultaneous operations (e.g., rename + remove)
  2. Edge case where all nodes are removed
  3. Invalid node names in operations

Let's check if these cases are covered elsewhere:

tests/model/test_skeleton.py (5)

107-108: LGTM! Enhanced test coverage for node addition.

The additional assertions properly verify node presence, indexing, and multiple node addition scenarios.

Also applies to: 110-115, 122-124


145-147: LGTM! Good coverage of multiple edge addition.

The test case properly verifies the add_edges method with multiple edges and confirms correct edge indices.


169-193: Add tests for edge and symmetry preservation after renaming.

While the current test cases cover basic renaming functionality and error cases, they don't verify that edges and symmetries are preserved after renaming operations.


194-224: LGTM! Thorough testing of node removal functionality.

The test cases comprehensively verify:

  • Node index updates after removal
  • Edge list maintenance
  • Symmetry relationship updates
  • Both single and multiple node removal scenarios
  • Error handling for invalid cases

169-245: LGTM! Comprehensive test suite for skeleton manipulation utilities.

The test suite provides good coverage of the new functionality with:

  • Clear test organization
  • Comprehensive error case handling
  • Verification of data structure integrity
  • Consistent testing style

Minor suggestions for additional test cases have been provided above, but the current coverage is already robust.

sleap_io/model/instance.py (1)

311-343: Verify integration with Labels-level methods.

The documentation mentions that Labels-level methods should call this method automatically. Let's verify this integration to ensure proper skeleton synchronization.

sleap_io/model/skeleton.py (3)

Line range hint 74-110: LGTM! Type improvements enhance code safety and readability.

The introduction of NodeOrIndex type alias and the renaming of cache fields to more descriptive names (_name_to_node_cache and _node_to_ind_cache) improve code clarity. The callback method _nodes_on_setattr ensures cache consistency.


412-467: LGTM! Well-implemented node renaming with proper validation.

The node renaming implementation is thorough with:

  • Comprehensive input validation
  • Clear error messages
  • Proper cache updates
  • Excellent documentation including examples
🧰 Tools
🪛 Ruff

444, 453, 455: Use is and is not for type comparisons, or isinstance() for isinstance checks (E721)


552-589: LGTM! Well-implemented node reordering with clear documentation.

The node reordering implementation is solid with:

  • Proper validation of the new order
  • Clear warnings about instance updates
  • Comprehensive documentation about potential side effects

Comment on lines +329 to +342
# Create a new dictionary to hold the updated points
new_points = {}

# Iterate over the nodes in the skeleton
for node in self.skeleton.nodes:
    # Get the point associated with the node
    point = self.points.get(node, None)

    # If the point is not None, add it to the new dictionary
    if point is not None:
        new_points[node] = point

# Update the points dictionary
self.points = new_points

🛠️ Refactor suggestion

Consider optimizing for the no-change case.

The current implementation always creates a new dictionary, even when no changes are needed. Consider adding a fast path for when the points dictionary already matches the skeleton.

     def update_skeleton(self) -> None:
+        # Fast path: nothing to do if the keys already match the nodes, in order
+        # (equal ordered lists imply equal sets, so one comparison is enough)
+        if list(self.points.keys()) == list(self.skeleton.nodes):
+            return
+
         # Create a new dictionary to hold the updated points
         new_points = {}
 
         # Iterate over the nodes in the skeleton
         for node in self.skeleton.nodes:
             # Get the point associated with the node
             point = self.points.get(node, None)
 
             # If the point is not None, add it to the new dictionary
             if point is not None:
                 new_points[node] = point
 
         # Update the points dictionary
         self.points = new_points
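
The fast path can be illustrated on plain dictionaries. `sync_points` below is a hypothetical free function written for this sketch, not part of sleap_io, and it uses string keys where the real code uses `Node` objects. It returns the input unchanged when the keys already match the node list in order, and otherwise rebuilds the mapping in node order while dropping stale keys.

```python
def sync_points(points: dict, nodes: list) -> dict:
    """Return `points` restricted to `nodes` and ordered like `nodes`."""
    # Fast path: keys already match the node list, order included.
    if list(points) == list(nodes):
        return points
    # Slow path: rebuild in node order, dropping keys that are no longer nodes.
    return {node: points[node] for node in nodes if node in points}


points = {"A": (0, 0), "B": (1, 1), "C": (2, 2)}
unchanged = sync_points(points, ["A", "B", "C"])  # same dict object comes back
reordered = sync_points(points, ["C", "A"])       # "B" dropped, "C" now first
```

Whether the saving matters depends on how often `update_skeleton()` is called on instances that are already in sync; the comparison itself is still linear in the number of nodes.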

Committable suggestion skipped: line range outside the PR's diff.

Comment on lines +161 to +185
def rebuild_cache(self, nodes: list[Node] | None = None):
    """Rebuild the node name/index to `Node` map caches.

    Args:
        nodes: A list of `Node` objects to update the cache with. If not provided,
            the cache will be updated with the current nodes in the skeleton. If
            nodes are provided, the cache will be updated with the provided nodes,
            but the current nodes in the skeleton will not be updated. Default is
            `None`.

    Notes:
        This function should be called when nodes or node list is mutated to update
        the lookup caches for indexing nodes by name or `Node` object.

        This is done automatically when nodes are added or removed from the skeleton
        using the convenience methods in this class.

        This method only needs to be used when manually mutating nodes or the node
        list directly.
    """
    if nodes is None:
        nodes = self.nodes
    self._name_to_node_cache = {node.name: node for node in nodes}
    self._node_to_ind_cache = {node: i for i, node in enumerate(nodes)}

🛠️ Refactor suggestion

Consider adding error handling for duplicate node names.

While the cache rebuilding logic is solid, it might silently overwrite entries if there are duplicate node names.

 def rebuild_cache(self, nodes: list[Node] | None = None):
     if nodes is None:
         nodes = self.nodes
+    # Check for duplicate names
+    name_counts = {}
+    for node in nodes:
+        name_counts[node.name] = name_counts.get(node.name, 0) + 1
+        if name_counts[node.name] > 1:
+            raise ValueError(f"Duplicate node name found: {node.name}")
     self._name_to_node_cache = {node.name: node for node in nodes}
     self._node_to_ind_cache = {node: i for i, node in enumerate(nodes)}
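
The same duplicate check can be written more compactly with `collections.Counter`. This is a minimal sketch on plain name strings; `find_duplicate_names` and `build_name_cache` are hypothetical helpers introduced here, not sleap_io functions. Unlike the loop above, it reports every duplicated name at once rather than stopping at the first.

```python
from collections import Counter


def find_duplicate_names(names: list[str]) -> list[str]:
    """Return the names that occur more than once, in first-seen order."""
    counts = Counter(names)
    return [name for name, count in counts.items() if count > 1]


def build_name_cache(names: list[str]) -> dict[str, int]:
    """Map each name to its index, refusing duplicates instead of overwriting."""
    duplicates = find_duplicate_names(names)
    if duplicates:
        raise ValueError(f"Duplicate node names found: {duplicates}")
    return {name: i for i, name in enumerate(names)}
```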

Comment on lines 478 to 551
def remove_nodes(self, nodes: list[NodeOrIndex]):
    """Remove nodes from the skeleton.

    Args:
        nodes: A list of node names, indices, or `Node` objects to remove.

    Notes:
        This method should always be used when removing nodes from the skeleton as
        it handles updating the lookup caches necessary for indexing nodes by name.

        Any edges and symmetries that are connected to the removed nodes will also
        be removed.

    Warning:
        **This method does NOT update instances** that use this skeleton to reflect
        changes.

        It is recommended to use the `Labels.remove_nodes()` method which will
        update all contained instances to reflect the changes made to the skeleton.

        To manually update instances after this method is called, call
        `Instance.update_skeleton()` on each instance that uses this skeleton.
    """
    # Standardize input and make a pre-mutation copy before keys are changed.
    rm_node_objs = [self.require_node(node, add_missing=False) for node in nodes]

    # Remove nodes from the skeleton.
    for node in rm_node_objs:
        self.nodes.remove(node)
        del self._name_to_node_cache[node.name]

    # Remove edges connected to the removed nodes.
    self.edges = [
        edge
        for edge in self.edges
        if edge.source not in rm_node_objs and edge.destination not in rm_node_objs
    ]

    # Remove symmetries connected to the removed nodes.
    self.symmetries = [
        symmetry
        for symmetry in self.symmetries
        if symmetry.nodes.isdisjoint(rm_node_objs)
    ]

    # Update node index map.
    self.rebuild_cache()

def remove_node(self, node: NodeOrIndex):
    """Remove a single node from the skeleton.

    Args:
        node: The node to remove. Can be specified as a string name, integer index,
            or `Node` object.

    Notes:
        This method should always be used when removing nodes from the skeleton as
        it handles updating the lookup caches necessary for indexing nodes by name.

        Any edges and symmetries that are connected to the removed node will also be
        removed.

    Warning:
        **This method does NOT update instances** that use this skeleton to reflect
        changes.

        It is recommended to use the `Labels.remove_nodes()` method which will
        update all contained instances to reflect the changes made to the skeleton.

        To manually update instances after this method is called, call
        `Instance.update_skeleton()` on each instance that uses this skeleton.
    """
    self.remove_nodes([node])

🛠️ Refactor suggestion

Consider making node removal atomic.

While the node removal implementation is thorough, it could fail mid-operation leaving the skeleton in an inconsistent state.

 def remove_nodes(self, nodes: list[NodeOrIndex]):
     # Standardize input and make a pre-mutation copy before keys are changed.
     rm_node_objs = [self.require_node(node, add_missing=False) for node in nodes]
+    
+    # Create copies of current state
+    new_nodes = self.nodes.copy()
+    new_edges = self.edges.copy()
+    new_symmetries = self.symmetries.copy()
+    new_name_cache = self._name_to_node_cache.copy()
+    
+    try:
+        # Remove nodes
+        for node in rm_node_objs:
+            new_nodes.remove(node)
+            del new_name_cache[node.name]
+        
+        # Update edges and symmetries
+        new_edges = [
+            edge for edge in new_edges
+            if edge.source not in rm_node_objs and edge.destination not in rm_node_objs
+        ]
+        new_symmetries = [
+            symmetry for symmetry in new_symmetries
+            if symmetry.nodes.isdisjoint(rm_node_objs)
+        ]
+        
+        # Commit changes
+        self.nodes = new_nodes
+        self.edges = new_edges
+        self.symmetries = new_symmetries
+        self._name_to_node_cache = new_name_cache
+        self.rebuild_cache()
+    except Exception as e:
+        # Nothing has been committed at this point, so the skeleton is unchanged
+        raise ValueError(f"Failed to remove nodes: {e}") from e
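
The copy-then-commit idea behind this suggestion can be shown in isolation. `MiniGraph` is a hypothetical stand-in that stores node names and name pairs, not the sleap_io `Skeleton`. It validates all inputs first, builds the new state on the side, and only then assigns it, so a failed call leaves the object untouched without needing a `try`/`except`.

```python
class MiniGraph:
    """Hypothetical stand-in: node names plus (source, destination) name pairs."""

    def __init__(self, nodes, edges):
        self.nodes = list(nodes)
        self.edges = list(edges)

    def remove_nodes(self, names):
        # Validate everything before touching any state.
        missing = [name for name in names if name not in self.nodes]
        if missing:
            raise IndexError(f"Nodes not found: {missing}")
        doomed = set(names)
        # Build the new state on the side...
        new_nodes = [n for n in self.nodes if n not in doomed]
        new_edges = [(s, d) for s, d in self.edges if s not in doomed and d not in doomed]
        # ...then commit with plain assignments.
        self.nodes, self.edges = new_nodes, new_edges


graph = MiniGraph(["A", "B", "C"], [("A", "B"), ("B", "C")])
try:
    graph.remove_nodes(["A", "Z"])  # "Z" does not exist, so nothing is removed
except IndexError:
    pass
state_after_failure = (list(graph.nodes), list(graph.edges))
graph.remove_nodes(["A"])
```

In the PR's implementation, `require_node(node, add_missing=False)` already runs for every input before the first mutation, which plays the role of the validation step here.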

Comment on lines +269 to +349
 def __getitem__(self, idx: NodeOrIndex) -> Node:
     """Return a `Node` when indexing by name or integer."""
     if type(idx) == int:
         return self.nodes[idx]
     elif type(idx) == str:
-        return self._node_name_map[idx]
+        return self._name_to_node_cache[idx]
     else:
         raise IndexError(f"Invalid indexing argument for skeleton: {idx}")

+def __contains__(self, node: NodeOrIndex) -> bool:
+    """Check if a node is in the skeleton."""
+    if type(node) == str:
+        return node in self._name_to_node_cache
+    elif type(node) == Node:
+        return node in self.nodes
+    elif type(node) == int:
+        return 0 <= node < len(self.nodes)
+    else:
+        raise ValueError(f"Invalid node type for skeleton: {node}")
+
 def add_node(self, node: Node | str):
     """Add a `Node` to the skeleton.

     Args:
         node: A `Node` object or a string name to create a new node.

     Raises:
         ValueError: If the node already exists in the skeleton or if the node is
             not specified as a `Node` or string.
     """
-    node_name = node.name if type(node) == Node else node
-    if node_name in self._node_name_map:
-        raise ValueError(f"Node '{node_name}' already exists in the skeleton.")
+    if node in self:
+        raise ValueError(f"Node '{node}' already exists in the skeleton.")

     if type(node) == str:
         node = Node(node)
-    if node not in self.nodes:
-        self.nodes.append(node)
-        self._update_node_map(None, self.nodes)

-def add_edge(self, src: Edge | Node | str = None, dst: Node | str = None):
+    if type(node) != Node:
+        raise ValueError(f"Invalid node type: {node} ({type(node)})")
+
+    self.nodes.append(node)
+
+    # Atomic update of the cache.
+    self._name_to_node_cache[node.name] = node
+    self._node_to_ind_cache[node] = len(self.nodes) - 1
+
+def add_nodes(self, nodes: list[Node | str]):
+    """Add multiple `Node`s to the skeleton.
+
+    Args:
+        nodes: A list of `Node` objects or string names to create new nodes.
+    """
+    for node in nodes:
+        self.add_node(node)
+
+def require_node(self, node: NodeOrIndex, add_missing: bool = True) -> Node:
+    """Return a `Node` object, handling indexing and adding missing nodes.
+
+    Args:
+        node: A `Node` object, name or index.
+        add_missing: If `True`, missing nodes will be added to the skeleton. If
+            `False`, an error will be raised if the node is not found. Default is
+            `True`.
+
+    Returns:
+        The `Node` object.
+
+    Raises:
+        IndexError: If the node is not found in the skeleton and `add_missing` is
+            `False`.
+    """
+    if node not in self:
+        if add_missing:
+            self.add_node(node)
+        else:
+            raise IndexError(f"Node '{node}' not found in the skeleton.")
+
+    if type(node) == Node:
+        return node
+
+    return self[node]

⚠️ Potential issue

Replace type() comparisons with isinstance() checks.

Following Python best practices, use isinstance() for type checking as it handles inheritance properly.

-    if type(idx) == int:
+    if isinstance(idx, int):
         return self.nodes[idx]
-    elif type(idx) == str:
+    elif isinstance(idx, str):
         return self._name_to_node_cache[idx]

-    if type(node) == str:
+    if isinstance(node, str):
         return node in self._name_to_node_cache
-    elif type(node) == Node:
+    elif isinstance(node, Node):
         return node in self.nodes
-    elif type(node) == int:
+    elif isinstance(node, int):
         return 0 <= node < len(self.nodes)

Committable suggestion skipped: line range outside the PR's diff.
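
To make the difference between the two checks concrete, here is a minimal sketch. The `Node` and `LabeledNode` classes are hypothetical stand-ins defined only for this example; they are not the sleap_io classes.

```python
class Node:
    """Hypothetical stand-in for a skeleton node (not the sleap_io class)."""

    def __init__(self, name: str):
        self.name = name


class LabeledNode(Node):
    """Hypothetical subclass, used only to show how the two checks differ."""


def is_node_exact(obj) -> bool:
    # Exact type comparison: subclasses do not match.
    return type(obj) == Node


def is_node_instance(obj) -> bool:
    # isinstance() accepts subclasses as well.
    return isinstance(obj, Node)


sub = LabeledNode("head")
exact = is_node_exact(sub)
inclusive = is_node_instance(sub)

# bool is a subclass of int, so the two styles also disagree on True/False.
bool_is_int_exact = type(True) == int
bool_is_int_instance = isinstance(True, int)
```

One side effect worth keeping in mind: because `bool` subclasses `int`, switching an integer branch to `isinstance(node, int)` would start treating `True` and `False` as indices, so the change is not purely cosmetic.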

🧰 Tools
🪛 Ruff

271-271: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


273-273: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


280-280: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


282-282: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


284-284: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


302-302: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


305-305: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


345-345: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 6

🧹 Outside diff range and nitpick comments (7)
tests/model/test_instance.py (2)

169-200: LGTM! Well-structured test cases for skeleton updates.

The test cases effectively cover the three main skeleton manipulation scenarios (renaming, removal, reordering) with good assertions for both success and error cases.

Consider enhancing the reordering test with explicit assertions on the numpy array values before and after the update to make the expected behavior clearer. For example:

# Reorder nodes
old_order = inst.numpy()  # capture values in the original node order
skel.reorder_nodes(["Z", "X"])
assert_equal(inst.numpy(), [[2, 2], [0, 0]])  # verify immediate reorder effect
inst.update_skeleton()
assert_equal(inst.numpy(), old_order[::-1])  # values still follow the new node order

202-232: LGTM! Comprehensive test coverage for skeleton replacement scenarios.

The test cases effectively cover full replacement, partial replacement, and the fast path optimization with good assertions for data integrity and cleanup.

Consider adding a performance comparison for the fast path test to validate the optimization:

# Fast path performance comparison
import time

# Standard path timing
start = time.perf_counter()
inst.replace_skeleton(new_skel, node_map={"A": "X", "B": "Y", "C": "Z"})
standard_time = time.perf_counter() - start

# Fast path timing
start = time.perf_counter()
inst.replace_skeleton(new_skel, rev_node_map=rev_node_map)
fast_time = time.perf_counter() - start

# Verify performance improvement
assert fast_time < standard_time, "Fast path should be faster than standard path"
sleap_io/model/instance.py (1)

343-394: Consider enhancing documentation and input validation.

The implementation is solid, but could benefit from:

  1. Input validation for new_skeleton to ensure it's not None
  2. Adding an Examples section to the docstring showing common usage patterns

Add type validation and enhance docstring:

     def replace_skeleton(
         self,
         new_skeleton: Skeleton,
         node_map: dict[NodeOrIndex, NodeOrIndex] | None = None,
         rev_node_map: dict[NodeOrIndex, NodeOrIndex] | None = None,
     ):
         """Replace the skeleton associated with the instance.
 
         The points dictionary will be updated to match the new skeleton.
 
         Args:
             new_skeleton: The new `Skeleton` to associate with the instance.
             node_map: Dictionary mapping nodes in the old skeleton to nodes in the new
                 skeleton. Keys and values can be specified as `Node` objects, integer
                 indices, or string names. If not provided, only nodes with identical
                 names will be mapped. Points associated with unmapped nodes will be
                 removed.
             rev_node_map: Dictionary mapping nodes in the new skeleton to nodes in the
                 old skeleton. This is used internally when calling from
                 `Labels.replace_skeleton()` as it is more efficient to compute this
                 mapping once and pass it to all instances. No validation is done on this
                 mapping, so nodes are expected to be `Node` objects.
+
+        Examples:
+            >>> # Replace with new skeleton, automatically mapping nodes by name
+            >>> instance.replace_skeleton(new_skeleton)
+            >>>
+            >>> # Replace with explicit node mapping
+            >>> instance.replace_skeleton(
+            ...     new_skeleton,
+            ...     node_map={"head": "nose", "tail": "tail_tip"}
+            ... )
         """
+        if new_skeleton is None:
+            raise ValueError("new_skeleton cannot be None")
+
         if rev_node_map is None:
tests/model/test_labels.py (2)

670-682: Consider adding edge cases to the node renaming tests.

While the test covers the basic functionality well, consider adding tests for:

  • Attempting to rename non-existent nodes
  • Renaming to already existing node names
  • Empty or invalid node names

703-719: Consider adding validation tests for node reordering.

While the test covers the basic functionality well, consider adding tests for:

  • Reordering with missing nodes
  • Reordering with duplicate nodes
  • Reordering with invalid node names
sleap_io/model/labels.py (2)

468-585: LGTM: Well-implemented node manipulation methods with thorough documentation.

The methods properly handle both skeleton and instance updates, with comprehensive error handling. However, there's an opportunity to reduce code duplication.

Consider extracting the skeleton validation logic into a private helper method to reduce duplication. Here's a suggested implementation:

+ def _validate_and_get_skeleton(self, skeleton: Skeleton | None = None) -> Skeleton:
+     """Validate and return the skeleton to operate on.
+     
+     Args:
+         skeleton: Optional skeleton to validate. If None, uses the default skeleton.
+         
+     Returns:
+         The validated skeleton to use.
+         
+     Raises:
+         ValueError: If there is more than one skeleton but none was specified.
+     """
+     if skeleton is None:
+         if len(self.skeletons) != 1:
+             raise ValueError(
+                 "Skeleton must be specified when there is more than one skeleton in "
+                 "the labels."
+             )
+         return self.skeleton
+     return skeleton

  def rename_nodes(
      self,
      name_map: dict[NodeOrIndex, str] | list[str],
      skeleton: Skeleton | None = None,
  ):
-     if skeleton is None:
-         if len(self.skeletons) != 1:
-             raise ValueError(
-                 "Skeleton must be specified when there is more than one skeleton in "
-                 "the labels."
-             )
-         skeleton = self.skeleton
+     skeleton = self._validate_and_get_skeleton(skeleton)
      # ... rest of the method

This refactor would:

  1. Reduce code duplication
  2. Make the validation logic more maintainable
  3. Keep the same functionality and error handling

586-647: LGTM: Comprehensive skeleton replacement implementation with proper node mapping.

The method handles both explicit and automatic node mapping well, with appropriate warnings about data loss. The instance update loop could optionally be restructured.

Consider separating the filtering step from the update step:

- for inst in self.instances:
-     if inst.skeleton == old_skeleton:
-         inst.replace_skeleton(new_skeleton, rev_node_map=rev_node_map)
+ # Filter lazily; each instance is still compared against old_skeleton once
+ affected_instances = (inst for inst in self.instances if inst.skeleton == old_skeleton)
+ for inst in affected_instances:
+     inst.replace_skeleton(new_skeleton, rev_node_map=rev_node_map)

This restructuring:

  1. Performs the same number of skeleton equality checks (one per instance)
  2. Adds no intermediate list, since the filter is a generator expression
  3. Maintains the same functionality
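
One way to check how many equality comparisons a filtered loop performs is to count them. The sketch below uses a hypothetical `CountingSkeleton` class (an assumption for illustration, not part of sleap_io) and compares an inline `if` with a generator-expression filter. In this toy setup both forms compare every element exactly once.

```python
class CountingSkeleton:
    """Hypothetical skeleton stand-in that counts equality comparisons."""

    eq_calls = 0

    def __init__(self, name: str):
        self.name = name

    def __eq__(self, other):
        CountingSkeleton.eq_calls += 1
        return isinstance(other, CountingSkeleton) and self.name == other.name


def count_checks_loop(skeletons, target):
    """Filter with an inline `if` and report (matches, comparisons)."""
    CountingSkeleton.eq_calls = 0
    matched = []
    for skel in skeletons:
        if skel == target:
            matched.append(skel)
    return len(matched), CountingSkeleton.eq_calls


def count_checks_generator(skeletons, target):
    """Filter with a generator expression and report (matches, comparisons)."""
    CountingSkeleton.eq_calls = 0
    affected = (skel for skel in skeletons if skel == target)
    matched = [skel for skel in affected]
    return len(matched), CountingSkeleton.eq_calls


skeletons = [
    CountingSkeleton("old"),
    CountingSkeleton("other"),
    CountingSkeleton("old"),
]
target = CountingSkeleton("old")

loop_result = count_checks_loop(skeletons, target)
gen_result = count_checks_generator(skeletons, target)
```

A real reduction in comparisons would need a different approach, for example grouping instances by skeleton identity ahead of time; that is outside what this sketch shows.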
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between e50025d and 897752b.

📒 Files selected for processing (5)
  • sleap_io/model/instance.py (2 hunks)
  • sleap_io/model/labels.py (2 hunks)
  • sleap_io/model/skeleton.py (6 hunks)
  • tests/model/test_instance.py (1 hunks)
  • tests/model/test_labels.py (2 hunks)
🧰 Additional context used
🪛 Ruff
sleap_io/model/skeleton.py

264-264: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


271-271: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


273-273: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


280-280: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


282-282: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


284-284: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


302-302: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


305-305: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


345-345: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


362-362: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


373-373: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


399-399: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


444-444: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


453-453: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


455-455: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)

tests/model/test_labels.py

3-3: numpy.testing.assert_equal imported but unused

Remove unused import: numpy.testing.assert_equal

(F401)

🔇 Additional comments (11)
sleap_io/model/instance.py (2)

14-14: LGTM!

The import is correctly placed and necessary for type hints in the new replace_skeleton method.


312-341: Implementation is correct but can be optimized.

The method correctly handles point updates when the skeleton changes, but could benefit from optimization.

The previous review comment about optimizing for the no-change case is still valid. Consider adding a fast path check before creating a new dictionary.

sleap_io/model/skeleton.py (4)

74-75: LGTM! Good addition of type alias.

The NodeOrIndex type alias improves code readability and type safety by clearly defining acceptable node reference types.


323-349: LGTM! Well-implemented utility method.

The require_node method provides a robust way to handle node references with good error handling and flexibility.

🧰 Tools
🪛 Ruff

345-345: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


552-589: LGTM! Well-implemented reordering functionality.

The implementation includes proper validation and clear documentation about potential side effects on instances.


161-185: 🛠️ Refactor suggestion

Consider adding duplicate node name detection.

While the cache rebuilding logic is solid, it might silently overwrite entries if there are duplicate node names.

 def rebuild_cache(self, nodes: list[Node] | None = None):
     if nodes is None:
         nodes = self.nodes
+    # Check for duplicate names
+    name_counts = {}
+    for node in nodes:
+        name_counts[node.name] = name_counts.get(node.name, 0) + 1
+        if name_counts[node.name] > 1:
+            raise ValueError(f"Duplicate node name found: {node.name}")
     self._name_to_node_cache = {node.name: node for node in nodes}
     self._node_to_ind_cache = {node: i for i, node in enumerate(nodes)}

Likely invalid or redundant comment.
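
The overwrite behavior mentioned above is easy to reproduce with a plain dict comprehension. This is a minimal sketch using a hypothetical `Node` stand-in and a helper name (`find_duplicate_names`) invented for the example; it is not the library's implementation.

```python
from collections import Counter


class Node:
    """Hypothetical stand-in for a skeleton node (not the sleap_io class)."""

    def __init__(self, name: str):
        self.name = name


def find_duplicate_names(nodes):
    """Return the sorted list of names that occur more than once."""
    counts = Counter(node.name for node in nodes)
    return sorted(name for name, n in counts.items() if n > 1)


nodes = [Node("head"), Node("tail"), Node("head")]

# A name-keyed dict keeps only the last node for each repeated name.
name_cache = {node.name: node for node in nodes}

duplicates = find_duplicate_names(nodes)
```

Here the cache ends up with two entries for three nodes, and `"head"` resolves to the last node with that name, which is the silent overwrite the comment describes.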

tests/model/test_labels.py (3)

638-668: LGTM! Well-structured test for the instances property.

The test thoroughly covers both single and multiple instance scenarios, ensuring proper instance counting across frames.


684-701: LGTM! Comprehensive test coverage for node removal.

The test effectively verifies:

  • Node removal from skeleton
  • Impact on instance data
  • Error handling for invalid cases

721-749: LGTM! Excellent test coverage for skeleton replacement.

The test comprehensively covers all skeleton replacement scenarios:

  • Full node mapping
  • Partial (inferred) mapping
  • No mapping case
  • Proper handling of instance data updates
sleap_io/model/labels.py (2)

16-16: LGTM: Import changes are appropriate.

The new imports support the added functionality for skeleton manipulation and type hints.

Also applies to: 24-24, 26-26


463-466: LGTM: Efficient implementation of instances property.

The use of a generator expression instead of a list comprehension is memory-efficient, especially for large datasets with many instances.
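
As an illustration of the lazy pattern described here, the sketch below flattens instances across frames with a generator expression. `Frame` and `iter_instances` are hypothetical names chosen for the example, and the strings stand in for instance objects; the actual property in `labels.py` may differ.

```python
from dataclasses import dataclass, field


@dataclass
class Frame:
    """Hypothetical stand-in for a labeled frame holding instances."""

    instances: list = field(default_factory=list)


def iter_instances(frames):
    """Lazily yield every instance across all frames without building a list."""
    return (inst for frame in frames for inst in frame.instances)


frames = [Frame(["a", "b"]), Frame([]), Frame(["c"])]

gen = iter_instances(frames)
first = next(gen)      # only the first item has been produced so far
rest = list(gen)       # consume the remainder
exhausted = list(gen)  # a generator is single-pass, so this is empty
```

The trade-off is that a generator can be consumed only once and has no `len()`; callers that need either must materialize it with `list(...)` first.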

Comment on lines +169 to +232
def test_instance_update_skeleton():
    skel = Skeleton(["A", "B", "C"])
    inst = Instance.from_numpy([[0, 0], [1, 1], [2, 2]], skeleton=skel)

    # No need to update on rename
    skel.rename_nodes({"A": "X", "B": "Y", "C": "Z"})
    assert inst["X"].x == 0
    assert inst["Y"].x == 1
    assert inst["Z"].x == 2
    assert_equal(inst.numpy(), [[0, 0], [1, 1], [2, 2]])

    # Remove a node from the skeleton
    Y = skel["Y"]
    skel.remove_node("Y")
    assert Y not in skel

    with pytest.raises(KeyError):
        inst.numpy()  # .numpy() breaks without update
    assert Y in inst.points  # and the points dict still has the old key
    inst.update_skeleton()
    assert Y not in inst.points  # after update, the old key is gone
    assert_equal(inst.numpy(), [[0, 0], [2, 2]])

    # Reorder nodes
    skel.reorder_nodes(["Z", "X"])
    assert_equal(inst.numpy(), [[2, 2], [0, 0]])  # .numpy() works without update
    assert (
        list(inst.points.keys()) != skel.nodes
    )  # but the points dict still has the old order
    inst.update_skeleton()
    assert list(inst.points.keys()) == skel.nodes  # after update, the order is correct


def test_instance_replace_skeleton():
    # Full replacement
    old_skel = Skeleton(["A", "B", "C"])
    inst = Instance.from_numpy([[0, 0], [1, 1], [2, 2]], skeleton=old_skel)
    new_skel = Skeleton(["X", "Y", "Z"])
    inst.replace_skeleton(new_skel, node_map={"A": "X", "B": "Y", "C": "Z"})
    assert inst.skeleton == new_skel
    assert_equal(inst.numpy(), [[0, 0], [1, 1], [2, 2]])
    assert list(inst.points.keys()) == new_skel.nodes

    # Partial replacement
    old_skel = Skeleton(["A", "B", "C"])
    inst = Instance.from_numpy([[0, 0], [1, 1], [2, 2]], skeleton=old_skel)
    new_skel = Skeleton(["X", "C", "Y"])
    inst.replace_skeleton(new_skel)
    assert inst.skeleton == new_skel
    assert_equal(inst.numpy(), [[np.nan, np.nan], [2, 2], [np.nan, np.nan]])
    assert new_skel["C"] in inst.points
    assert old_skel["A"] not in inst.points
    assert old_skel["C"] not in inst.points

    # Fast path with reverse node map
    old_skel = Skeleton(["A", "B", "C"])
    inst = Instance.from_numpy([[0, 0], [1, 1], [2, 2]], skeleton=old_skel)
    new_skel = Skeleton(["X", "Y", "Z"])
    rev_node_map = {
        new_node: old_node for new_node, old_node in zip(new_skel.nodes, old_skel.nodes)
    }
    inst.replace_skeleton(new_skel, rev_node_map=rev_node_map)
    assert inst.skeleton == new_skel
    assert_equal(inst.numpy(), [[0, 0], [1, 1], [2, 2]])

🛠️ Refactor suggestion

Consider using pytest fixtures to reduce setup duplication.

The new test functions share similar setup code for creating skeletons and instances. Consider extracting these into fixtures to improve test isolation and reduce duplication:

@pytest.fixture
def basic_skeleton():
    return Skeleton(["A", "B", "C"])

@pytest.fixture
def basic_instance(basic_skeleton):
    return Instance.from_numpy([[0, 0], [1, 1], [2, 2]], skeleton=basic_skeleton)

def test_instance_update_skeleton(basic_skeleton, basic_instance):
    # Use fixtures instead of creating new objects
    skel = basic_skeleton
    inst = basic_instance
    ...

def test_instance_replace_skeleton(basic_skeleton, basic_instance):
    # Use fixtures here as well
    old_skel = basic_skeleton
    inst = basic_instance
    ...

Comment on lines +97 to +105
+    def _nodes_on_setattr(self, attr, new_nodes):
+        """Callback to update caches when nodes are set."""
+        self.rebuild_cache(nodes=new_nodes)
+        return new_nodes

-    nodes: list[Node] = field(factory=list, on_setattr=_update_node_map)
+    nodes: list[Node] = field(
+        factory=list,
+        on_setattr=_nodes_on_setattr,
+    )

🛠️ Refactor suggestion

Consider making the callback more robust.

The callback mechanism for node updates could benefit from additional validation.

 def _nodes_on_setattr(self, attr, new_nodes):
     """Callback to update caches when nodes are set."""
+    if not isinstance(new_nodes, list):
+        raise TypeError(f"Expected list of nodes, got {type(new_nodes)}")
+    if not all(isinstance(node, (Node, str)) for node in new_nodes):
+        raise TypeError("All nodes must be Node objects or strings")
     self.rebuild_cache(nodes=new_nodes)
     return new_nodes

@@ -1,6 +1,6 @@
 """Test methods and functions in the sleap_io.model.labels file."""

-from numpy.testing import assert_equal
+from numpy.testing import assert_equal, assert_allclose

⚠️ Potential issue

Remove unused import.

The assert_equal import from numpy.testing is not used in the code.

-from numpy.testing import assert_equal, assert_allclose
+from numpy.testing import assert_allclose
🧰 Tools
🪛 Ruff

3-3: numpy.testing.assert_equal imported but unused

Remove unused import: numpy.testing.assert_equal

(F401)

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 5

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 897752b and 1a80f11.

📒 Files selected for processing (1)
  • sleap_io/model/skeleton.py (6 hunks)
🧰 Additional context used
🪛 Ruff
sleap_io/model/skeleton.py

10-10: typing imported but unused

Remove unused import: typing

(F401)


266-266: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


273-273: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


275-275: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


282-282: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


284-284: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


286-286: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


304-304: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


307-307: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


347-347: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


364-364: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


375-375: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


401-401: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


446-446: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


455-455: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


457-457: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)

🔇 Additional comments (8)
sleap_io/model/skeleton.py (8)

74-77: LGTM!

The type alias NodeOrIndex is defined correctly using the | operator for Python 3.10+. The commented-out alternatives for Python 3.12+ and older versions are also provided for reference.


Line range hint 194-207: LGTM!

The properties edge_inds and symmetry_inds have been updated to use tuple[int, int] instead of Tuple[int, int], reflecting a shift to a more modern type hinting style.


301-315: LGTM!

The add_node method correctly checks if the node already exists in the skeleton and raises a ValueError if it does. It also handles the case where the input is a string by creating a new Node object. The internal caches are updated atomically after the node is added.

🧰 Tools
🪛 Ruff

304-304: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


307-307: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)


316-324: LGTM!

The add_nodes method is a convenience method that calls add_node for each node in the input list. This ensures that the validation and cache updates are handled correctly for each node.


325-351: LGTM!

The require_node method is a useful utility that ensures a Node object is returned for various input types (Node, name, or index). It also handles adding missing nodes to the skeleton if add_missing is True, which is the default behavior.

🧰 Tools
🪛 Ruff

347-347: Use is and is not for type comparisons, or isinstance() for isinstance checks

(E721)
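
The lookup-or-create behavior described for `require_node` can be sketched with a small self-contained container. `MiniSkeleton` and its `Node` are hypothetical stand-ins that mirror the logic quoted in this review; they use `isinstance()` checks and are not the sleap_io classes themselves.

```python
class Node:
    """Hypothetical stand-in for a skeleton node (not the sleap_io class)."""

    def __init__(self, name: str):
        self.name = name


class MiniSkeleton:
    """Toy container mirroring the node lookup logic quoted in this review."""

    def __init__(self, names):
        self.nodes = [Node(name) for name in names]
        self._name_to_node = {node.name: node for node in self.nodes}

    def __contains__(self, node) -> bool:
        if isinstance(node, str):
            return node in self._name_to_node
        if isinstance(node, Node):
            return node in self.nodes
        if isinstance(node, int):
            return 0 <= node < len(self.nodes)
        raise ValueError(f"Invalid node type for skeleton: {node}")

    def __getitem__(self, idx) -> Node:
        if isinstance(idx, int):
            return self.nodes[idx]
        if isinstance(idx, str):
            return self._name_to_node[idx]
        raise IndexError(f"Invalid indexing argument for skeleton: {idx}")

    def add_node(self, node) -> None:
        if node in self:
            raise ValueError(f"Node '{node}' already exists in the skeleton.")
        if isinstance(node, str):
            node = Node(node)
        if not isinstance(node, Node):
            raise ValueError(f"Invalid node type: {node} ({type(node)})")
        self.nodes.append(node)
        self._name_to_node[node.name] = node

    def require_node(self, node, add_missing: bool = True) -> Node:
        if node not in self:
            if add_missing:
                self.add_node(node)
            else:
                raise IndexError(f"Node '{node}' not found in the skeleton.")
        if isinstance(node, Node):
            return node
        return self[node]


skel = MiniSkeleton(["A", "B"])
existing = skel.require_node("A")  # found by name
created = skel.require_node("C")   # missing, so it is created and added

try:
    skel.require_node("D", add_missing=False)
    missing_error = None
except IndexError as err:
    missing_error = str(err)
```

In this sketch, an out-of-range integer passed with `add_missing=True` surfaces as a `ValueError` from `add_node` rather than an `IndexError`, since an index cannot be turned into a new node.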


554-591: LGTM!

The reorder_nodes method correctly updates the order of nodes in the skeleton. It raises a ValueError if the new order is not the same length as the current nodes. The method also handles updating the internal caches after reordering.

The warning about instances not being automatically updated is valid and the recommendation to use Labels.reorder_nodes() or manually update instances is appropriate.


99-107: 🛠️ Refactor suggestion

Consider making the callback more robust.

The callback mechanism for node updates could benefit from additional validation.

 def _nodes_on_setattr(self, attr, new_nodes):
     """Callback to update caches when nodes are set."""
+    if not isinstance(new_nodes, list):
+        raise TypeError(f"Expected list of nodes, got {type(new_nodes)}")
+    if not all(isinstance(node, Node) for node in new_nodes):
+        raise TypeError("All nodes must be Node objects")
     self.rebuild_cache(nodes=new_nodes)
     return new_nodes

Likely invalid or redundant comment.


163-187: 🛠️ Refactor suggestion

Consider adding error handling for duplicate node names.

While the cache rebuilding logic is solid, it might silently overwrite entries if there are duplicate node names.

 def rebuild_cache(self, nodes: list[Node] | None = None):
     if nodes is None:
         nodes = self.nodes
+    # Check for duplicate names
+    name_counts = {}
+    for node in nodes:
+        name_counts[node.name] = name_counts.get(node.name, 0) + 1
+        if name_counts[node.name] > 1:
+            raise ValueError(f"Duplicate node name found: {node.name}")
     self._name_to_node_cache = {node.name: node for node in nodes}
     self._node_to_ind_cache = {node: i for i, node in enumerate(nodes)}

Likely invalid or redundant comment.

@@ -7,7 +7,7 @@

 from __future__ import annotations
 from attrs import define, field
-from typing import Optional, Tuple, Union
+import typing

⚠️ Potential issue

Remove unused import.

The typing module is imported but not used in the code.

-import typing
🧰 Tools
🪛 Ruff

10-10: typing imported but unused

Remove unused import: typing

(F401)

Comment on lines +480 to +553
    def remove_nodes(self, nodes: list[NodeOrIndex]):
        """Remove nodes from the skeleton.

        Args:
            nodes: A list of node names, indices, or `Node` objects to remove.

        Notes:
            This method handles updating the lookup caches necessary for indexing nodes
            by name.

            Any edges and symmetries that are connected to the removed nodes will also
            be removed.

        Warning:
            **This method does NOT update instances** that use this skeleton to reflect
            changes.

            It is recommended to use the `Labels.remove_nodes()` method which will
            update all contained instances to reflect the changes made to the skeleton.

            To manually update instances after this method is called, call
            `Instance.update_skeleton()` on each instance that uses this skeleton.
        """
        # Standardize input and make a pre-mutation copy before keys are changed.
        rm_node_objs = [self.require_node(node, add_missing=False) for node in nodes]

        # Remove nodes from the skeleton.
        for node in rm_node_objs:
            self.nodes.remove(node)
            del self._name_to_node_cache[node.name]

        # Remove edges connected to the removed nodes.
        self.edges = [
            edge
            for edge in self.edges
            if edge.source not in rm_node_objs and edge.destination not in rm_node_objs
        ]

        # Remove symmetries connected to the removed nodes.
        self.symmetries = [
            symmetry
            for symmetry in self.symmetries
            if symmetry.nodes.isdisjoint(rm_node_objs)
        ]

        # Update node index map.
        self.rebuild_cache()

    def remove_node(self, node: NodeOrIndex):
        """Remove a single node from the skeleton.

        Args:
            node: The node to remove. Can be specified as a string name, integer index,
                or `Node` object.

        Notes:
            This method handles updating the lookup caches necessary for indexing nodes
            by name.

            Any edges and symmetries that are connected to the removed node will also be
            removed.

        Warning:
            **This method does NOT update instances** that use this skeleton to reflect
            changes.

            It is recommended to use the `Labels.remove_nodes()` method which will
            update all contained instances to reflect the changes made to the skeleton.

            To manually update instances after this method is called, call
            `Instance.update_skeleton()` on each instance that uses this skeleton.
        """
        self.remove_nodes([node])

🛠️ Refactor suggestion

Consider making node removal atomic.

While the node removal implementation is thorough, it could fail mid-operation leaving the skeleton in an inconsistent state.

 def remove_nodes(self, nodes: list[NodeOrIndex]):
     # Standardize input and make a pre-mutation copy before keys are changed.
     rm_node_objs = [self.require_node(node, add_missing=False) for node in nodes]
+    
+    # Create copies of current state
+    new_nodes = self.nodes.copy()
+    new_edges = self.edges.copy()
+    new_symmetries = self.symmetries.copy()
+    new_name_cache = self._name_to_node_cache.copy()
+    
+    try:
+        # Remove nodes
+        for node in rm_node_objs:
+            new_nodes.remove(node)
+            del new_name_cache[node.name]
+        
+        # Update edges and symmetries
+        new_edges = [
+            edge for edge in new_edges
+            if edge.source not in rm_node_objs and edge.destination not in rm_node_objs
+        ]
+        new_symmetries = [
+            symmetry for symmetry in new_symmetries
+            if symmetry.nodes.isdisjoint(rm_node_objs)
+        ]
+        
+        # Commit changes
+        self.nodes = new_nodes
+        self.edges = new_edges
+        self.symmetries = new_symmetries
+        self._name_to_node_cache = new_name_cache
+        self.rebuild_cache()
+    except Exception as e:
+        # Rollback on error
+        raise ValueError(f"Failed to remove nodes: {str(e)}")
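
The difference between mutating as you go and validating before committing can be shown with plain lists. The sketch below uses strings as stand-ins for nodes and two hypothetical helpers (`remove_in_place`, `remove_validated`); it illustrates the general pattern rather than the skeleton implementation. The quoted `remove_nodes` already resolves every input through `require_node(..., add_missing=False)` before mutating, so unknown names fail early there; a repeated entry in the input list appears to be one remaining way to hit a failure partway through.

```python
def remove_in_place(nodes, to_remove):
    """Mutate as you go: a failure partway leaves `nodes` partially changed."""
    for name in to_remove:
        nodes.remove(name)  # raises ValueError if the name is not present


def remove_validated(nodes, to_remove):
    """Validate and de-duplicate first, then commit in a single step."""
    unique = list(dict.fromkeys(to_remove))  # drop repeats, keep order
    missing = [name for name in unique if name not in nodes]
    if missing:
        raise ValueError(f"Unknown nodes: {missing}")
    nodes[:] = [name for name in nodes if name not in unique]


# In-place removal fails on "Z" after "A" has already been removed.
partial = ["A", "B", "C"]
try:
    remove_in_place(partial, ["A", "Z"])
    in_place_failed = False
except ValueError:
    in_place_failed = True

# Validated removal fails before touching the list.
intact = ["A", "B", "C"]
try:
    remove_validated(intact, ["A", "Z"])
    validated_failed = False
except ValueError:
    validated_failed = True

# Repeated entries are handled by de-duplicating up front.
deduped = ["A", "B", "C"]
remove_validated(deduped, ["A", "A"])
```

Validating first avoids the need for a rollback path, because nothing is mutated until every input is known to be removable.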
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def remove_nodes(self, nodes: list[NodeOrIndex]):
"""Remove nodes from the skeleton.
Args:
nodes: A list of node names, indices, or `Node` objects to remove.
Notes:
This method handles updating the lookup caches necessary for indexing nodes
by name.
Any edges and symmetries that are connected to the removed nodes will also
be removed.
Warning:
**This method does NOT update instances** that use this skeleton to reflect
changes.
It is recommended to use the `Labels.remove_nodes()` method which will
update all contained instances to reflect the changes made to the skeleton.
To manually update instances after this method is called, call
`Instance.update_skeleton()` on each instance that uses this skeleton.
"""
# Standardize input and make a pre-mutation copy before keys are changed.
rm_node_objs = [self.require_node(node, add_missing=False) for node in nodes]
# Remove nodes from the skeleton.
for node in rm_node_objs:
self.nodes.remove(node)
del self._name_to_node_cache[node.name]
# Remove edges connected to the removed nodes.
self.edges = [
edge
for edge in self.edges
if edge.source not in rm_node_objs and edge.destination not in rm_node_objs
]
# Remove symmetries connected to the removed nodes.
self.symmetries = [
symmetry
for symmetry in self.symmetries
if symmetry.nodes.isdisjoint(rm_node_objs)
]
# Update node index map.
self.rebuild_cache()
def remove_node(self, node: NodeOrIndex):
"""Remove a single node from the skeleton.
Args:
node: The node to remove. Can be specified as a string name, integer index,
or `Node` object.
Notes:
This method handles updating the lookup caches necessary for indexing nodes
by name.
Any edges and symmetries that are connected to the removed node will also be
removed.
Warning:
**This method does NOT update instances** that use this skeleton to reflect
changes.
It is recommended to use the `Labels.remove_nodes()` method which will
update all contained instances to reflect the changes made to the skeleton.
To manually update instances after this method is called, call
`Instance.update_skeleton()` on each instance that uses this skeleton.
"""
self.remove_nodes([node])
def remove_nodes(self, nodes: list[NodeOrIndex]):
"""Remove nodes from the skeleton.
Args:
nodes: A list of node names, indices, or `Node` objects to remove.
Notes:
This method handles updating the lookup caches necessary for indexing nodes
by name.
Any edges and symmetries that are connected to the removed nodes will also
be removed.
Warning:
**This method does NOT update instances** that use this skeleton to reflect
changes.
It is recommended to use the `Labels.remove_nodes()` method which will
update all contained to reflect the changes made to the skeleton.
To manually update instances after this method is called, call
`instance.update_nodes()` on each instance that uses this skeleton.
"""
# Standardize input and make a pre-mutation copy before keys are changed.
rm_node_objs = [self.require_node(node, add_missing=False) for node in nodes]
# Create copies of current state
new_nodes = self.nodes.copy()
new_edges = self.edges.copy()
new_symmetries = self.symmetries.copy()
new_name_cache = self._name_to_node_cache.copy()
try:
# Remove nodes
for node in rm_node_objs:
new_nodes.remove(node)
del new_name_cache[node.name]
# Update edges and symmetries
new_edges = [
edge for edge in new_edges
if edge.source not in rm_node_objs and edge.destination not in rm_node_objs
]
new_symmetries = [
symmetry for symmetry in new_symmetries
if symmetry.nodes.isdisjoint(rm_node_objs)
]
# Commit changes
self.nodes = new_nodes
self.edges = new_edges
self.symmetries = new_symmetries
self._name_to_node_cache = new_name_cache
self.rebuild_cache()
except Exception as e:
# Rollback on error
raise ValueError(f"Failed to remove nodes: {str(e)}")
def remove_node(self, node: NodeOrIndex):
"""Remove a single node from the skeleton.
Args:
node: The node to remove. Can be specified as a string name, integer index,
or `Node` object.
Notes:
This method handles updating the lookup caches necessary for indexing nodes
by name.
Any edges and symmetries that are connected to the removed node will also be
removed.
Warning:
**This method does NOT update instances** that use this skeleton to reflect
changes.
It is recommended to use the `Labels.remove_nodes()` method which will
update all contained instances to reflect the changes made to the skeleton.
To manually update instances after this method is called, call
`Instance.update_skeleton()` on each instance that uses this skeleton.
"""
self.remove_nodes([node])

Comment on lines +414 to +479
def rename_nodes(self, name_map: dict[NodeOrIndex, str] | list[str]):
    """Rename nodes in the skeleton.

    Args:
        name_map: A dictionary mapping old node names to new node names. Keys can be
            specified as `Node` objects, integer indices, or string names. Values
            must be specified as string names.

            If a list of strings is provided of the same length as the current
            nodes, the nodes will be renamed to the names in the list in order.

    Raises:
        ValueError: If the new node names exist in the skeleton or if the old node
            names are not found in the skeleton.

    Notes:
        This method should always be used when renaming nodes in the skeleton as it
        handles updating the lookup caches necessary for indexing nodes by name.

        After renaming, instances using this skeleton **do NOT need to be updated**
        as the nodes are stored by reference in the skeleton, so changes are
        reflected automatically.

    Example:
        >>> skel = Skeleton(["A", "B", "C"], edges=[("A", "B"), ("B", "C")])
        >>> skel.rename_nodes({"A": "X", "B": "Y", "C": "Z"})
        >>> skel.node_names
        ["X", "Y", "Z"]
        >>> skel.rename_nodes(["a", "b", "c"])
        >>> skel.node_names
        ["a", "b", "c"]
    """
    if type(name_map) == list:
        if len(name_map) != len(self.nodes):
            raise ValueError(
                "List of new node names must be the same length as the current "
                "nodes."
            )
        name_map = {node: name for node, name in zip(self.nodes, name_map)}

    for old_name, new_name in name_map.items():
        if type(old_name) == Node:
            old_name = old_name.name
        if type(old_name) == int:
            old_name = self.nodes[old_name].name

        if old_name not in self._name_to_node_cache:
            raise ValueError(f"Node '{old_name}' not found in the skeleton.")
        if new_name in self._name_to_node_cache:
            raise ValueError(f"Node '{new_name}' already exists in the skeleton.")

        node = self._name_to_node_cache[old_name]
        node.name = new_name
        self._name_to_node_cache[new_name] = node
        del self._name_to_node_cache[old_name]

def rename_node(self, old_name: NodeOrIndex, new_name: str):
    """Rename a single node in the skeleton.

    Args:
        old_name: The name of the node to rename. Can also be specified as an
            integer index or `Node` object.
        new_name: The new name for the node.
    """
    self.rename_nodes({old_name: new_name})


🛠️ Refactor suggestion

Consider making node renaming atomic.

While the implementation is thorough, it could fail mid-operation leaving the skeleton in an inconsistent state.

 def rename_nodes(self, name_map: dict[NodeOrIndex, str] | list[str]):
+    # Validate and prepare all changes first
+    updates = {}
     if isinstance(name_map, list):
         if len(name_map) != len(self.nodes):
             raise ValueError(
                 "List of new node names must be the same length as the current nodes."
             )
         name_map = {node: name for node, name in zip(self.nodes, name_map)}

     for old_name, new_name in name_map.items():
         if isinstance(old_name, Node):
             old_name = old_name.name
         if isinstance(old_name, int):
             old_name = self.nodes[old_name].name

         if old_name not in self._name_to_node_cache:
             raise ValueError(f"Node '{old_name}' not found in the skeleton.")
         if new_name in self._name_to_node_cache:
             raise ValueError(f"Node '{new_name}' already exists in the skeleton.")
+        updates[old_name] = (self._name_to_node_cache[old_name], new_name)

+    # Apply all changes after validation
+    for old_name, (node, new_name) in updates.items():
-        node = self._name_to_node_cache[old_name]
         node.name = new_name
         self._name_to_node_cache[new_name] = node
         del self._name_to_node_cache[old_name]
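
The validate-then-apply idea in the diff above can be sketched in isolation. This is only an illustration on a plain `dict` standing in for the name cache; `rename_keys` is a hypothetical helper, not part of sleap-io. It also rejects duplicate new names, an extra check that the diff above does not perform.

```python
def rename_keys(cache: dict, name_map: dict) -> None:
    """Rename keys of `cache` in place, validating every rename before applying any.

    Hypothetical helper for illustration; `cache` plays the role of a
    name -> node lookup table.
    """
    # Phase 1: validate everything without touching `cache`.
    new_names = list(name_map.values())
    if len(set(new_names)) != len(new_names):
        raise ValueError("New names must be unique.")
    for old, new in name_map.items():
        if old not in cache:
            raise ValueError(f"'{old}' not found.")
        if new in cache:
            raise ValueError(f"'{new}' already exists.")

    # Phase 2: apply. Nothing here can fail on the conditions checked above.
    for old, new in name_map.items():
        cache[new] = cache.pop(old)


cache = {"A": 1, "B": 2}
try:
    # "Z" does not exist, so the whole call is rejected and "A" is not renamed.
    rename_keys(cache, {"A": "X", "Z": "Y"})
except ValueError:
    pass
unchanged_after_failure = dict(cache)
rename_keys(cache, {"A": "X"})
```

Because every check runs before the first mutation, a failing call leaves the mapping exactly as it was.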
🧰 Tools
🪛 Ruff

446, 455, 457: Use `is` and `is not` for type comparisons, or `isinstance()` for isinstance checks (E721)


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 5

🧹 Outside diff range and nitpick comments (1)
sleap_io/model/labels.py (1)

468-512: Add validation for node name uniqueness and list length.

While the implementation is solid, consider adding these validations:

  1. Ensure new node names are unique when using dict input
  2. Validate that the list length matches the node count when using list input
 def rename_nodes(
     self,
     name_map: dict[NodeOrIndex, str] | list[str],
     skeleton: Skeleton | None = None,
 ):
     if skeleton is None:
         if len(self.skeletons) != 1:
             raise ValueError(
                 "Skeleton must be specified when there is more than one skeleton in "
                 "the labels."
             )
         skeleton = self.skeleton

+    # Validate input
+    if isinstance(name_map, list):
+        if len(name_map) != len(skeleton.nodes):
+            raise ValueError(
+                f"Expected {len(skeleton.nodes)} names but got {len(name_map)}."
+            )
+    else:
+        new_names = set(name_map.values())
+        if len(new_names) != len(name_map):
+            raise ValueError("New node names must be unique.")

     skeleton.rename_nodes(name_map)
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between 1a80f11 and 57ee01c.

📒 Files selected for processing (2)
  • sleap_io/model/labels.py (2 hunks)
  • sleap_io/model/skeleton.py (6 hunks)
🧰 Additional context used
🪛 Ruff
sleap_io/model/skeleton.py

266, 273, 275, 282, 284, 286, 304, 307, 347, 364, 375, 401, 446, 455, 457: Use `is` and `is not` for type comparisons, or `isinstance()` for isinstance checks (E721)

🔇 Additional comments (4)
sleap_io/model/skeleton.py (2)

99-107: 🛠️ Refactor suggestion

Add type validation in the callback.

The callback mechanism for node updates could benefit from additional validation to ensure type safety.

 def _nodes_on_setattr(self, attr, new_nodes):
     """Callback to update caches when nodes are set."""
+    if not isinstance(new_nodes, list):
+        raise TypeError(f"Expected list of nodes, got {type(new_nodes)}")
+    if not all(isinstance(node, (Node, str)) for node in new_nodes):
+        raise TypeError("All nodes must be Node objects or strings")
     self.rebuild_cache(nodes=new_nodes)
     return new_nodes

Likely invalid or redundant comment.


10-10: ⚠️ Potential issue

Remove unused import.

The typing module is imported but not used directly in the code.

-import typing

Likely invalid or redundant comment.

sleap_io/model/labels.py (2)

463-466: LGTM! Efficient implementation of instances property.

The generator expression implementation is memory-efficient for iterating over instances.


513-546: LGTM! Well-implemented node removal with proper instance updates.

The method correctly handles:

  • Skeleton validation
  • Instance updates
  • Documentation of side effects

Comment on lines +271 to +289
 def __getitem__(self, idx: NodeOrIndex) -> Node:
     """Return a `Node` when indexing by name or integer."""
     if type(idx) == int:
         return self.nodes[idx]
     elif type(idx) == str:
-        return self._node_name_map[idx]
+        return self._name_to_node_cache[idx]
     else:
         raise IndexError(f"Invalid indexing argument for skeleton: {idx}")

+def __contains__(self, node: NodeOrIndex) -> bool:
+    """Check if a node is in the skeleton."""
+    if type(node) == str:
+        return node in self._name_to_node_cache
+    elif type(node) == Node:
+        return node in self.nodes
+    elif type(node) == int:
+        return 0 <= node < len(self.nodes)
+    else:
+        raise ValueError(f"Invalid node type for skeleton: {node}")

⚠️ Potential issue

Replace type() comparisons with isinstance() checks.

Multiple instances of type() == comparisons should be replaced with isinstance() checks to properly handle inheritance and follow Python best practices.

Apply these changes throughout the file:

-    if type(idx) == int:
+    if isinstance(idx, int):
         return self.nodes[idx]
-    elif type(idx) == str:
+    elif isinstance(idx, str):
         return self._name_to_node_cache[idx]

-    if type(node) == str:
+    if isinstance(node, str):
         return node in self._name_to_node_cache
-    elif type(node) == Node:
+    elif isinstance(node, Node):
         return node in self.nodes
-    elif type(node) == int:
+    elif isinstance(node, int):
         return 0 <= node < len(self.nodes)

-    if type(src) == tuple:
+    if isinstance(src, tuple):
         src, dst = src

-    if type(name_map) == list:
+    if isinstance(name_map, list):
         if len(name_map) != len(self.nodes):
             raise ValueError(
                 "List of new node names must be the same length as the current "
                 "nodes."
             )

Also applies to: 364-374, 446-458
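
To make the difference concrete, here is a small self-contained comparison. The `Node` class below is a stand-in and `LabeledNode` is a hypothetical subclass used only for illustration; neither is the sleap-io implementation.

```python
class Node:
    """Stand-in for a skeleton node (illustration only)."""

    def __init__(self, name: str):
        self.name = name


class LabeledNode(Node):
    """Hypothetical subclass, used to show how the two checks differ."""


node = LabeledNode("head")

# Exact type comparison ignores inheritance, so a subclass instance is rejected.
exact_match = type(node) == Node

# isinstance() accepts instances of subclasses as well.
subclass_match = isinstance(node, Node)

# Caveat: bool is a subclass of int, so switching to isinstance() also means
# that True/False would be accepted wherever an integer index is expected.
bool_passes_isinstance = isinstance(True, int)
bool_passes_exact = type(True) == int
```

The last two lines are worth keeping in mind when applying the change mechanically: the switch is not purely stylistic, since it widens what counts as an `int`.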

🧰 Tools
🪛 Ruff

273, 275, 282, 284, 286: Use `is` and `is not` for type comparisons, or `isinstance()` for isinstance checks (E721)

Comment on lines +74 to +76
NodeOrIndex = typing.Union[Node, str, int]
# NodeOrIndex: TypeAlias = Node | str | int # py >= 3.10
# type NodeOrIndex = Node | str | int # py >= 3.12

🛠️ Refactor suggestion

Consider using more modern type alias syntax.

The code includes commented alternatives for newer Python versions. Since the code supports Python 3.10+, you could use the more modern syntax.

-NodeOrIndex = typing.Union[Node, str, int]
+NodeOrIndex: TypeAlias = Node | str | int

Comment on lines +480 to +527
def remove_nodes(self, nodes: list[NodeOrIndex]):
    """Remove nodes from the skeleton.

    Args:
        nodes: A list of node names, indices, or `Node` objects to remove.

    Notes:
        This method handles updating the lookup caches necessary for indexing nodes
        by name.

        Any edges and symmetries that are connected to the removed nodes will also
        be removed.

    Warning:
        **This method does NOT update instances** that use this skeleton to reflect
        changes.

        It is recommended to use the `Labels.remove_nodes()` method which will
        update all contained instances to reflect the changes made to the skeleton.

        To manually update instances after this method is called, call
        `instance.update_nodes()` on each instance that uses this skeleton.
    """
    # Standardize input and make a pre-mutation copy before keys are changed.
    rm_node_objs = [self.require_node(node, add_missing=False) for node in nodes]

    # Remove nodes from the skeleton.
    for node in rm_node_objs:
        self.nodes.remove(node)
        del self._name_to_node_cache[node.name]

    # Remove edges connected to the removed nodes.
    self.edges = [
        edge
        for edge in self.edges
        if edge.source not in rm_node_objs and edge.destination not in rm_node_objs
    ]

    # Remove symmetries connected to the removed nodes.
    self.symmetries = [
        symmetry
        for symmetry in self.symmetries
        if symmetry.nodes.isdisjoint(rm_node_objs)
    ]

    # Update node index map.
    self.rebuild_cache()


🛠️ Refactor suggestion

Consider making node removal atomic.

While the node removal implementation is thorough, it could fail mid-operation leaving the skeleton in an inconsistent state. Consider implementing an atomic operation pattern.

 def remove_nodes(self, nodes: list[NodeOrIndex]):
     # Standardize input and make a pre-mutation copy before keys are changed.
     rm_node_objs = [self.require_node(node, add_missing=False) for node in nodes]
+    
+    # Create copies of current state
+    new_nodes = self.nodes.copy()
+    new_edges = self.edges.copy()
+    new_symmetries = self.symmetries.copy()
+    new_name_cache = self._name_to_node_cache.copy()
+    
+    try:
+        # Remove nodes
+        for node in rm_node_objs:
+            new_nodes.remove(node)
+            del new_name_cache[node.name]
+        
+        # Update edges and symmetries
+        new_edges = [
+            edge for edge in new_edges
+            if edge.source not in rm_node_objs and edge.destination not in rm_node_objs
+        ]
+        new_symmetries = [
+            symmetry for symmetry in new_symmetries
+            if symmetry.nodes.isdisjoint(rm_node_objs)
+        ]
+        
+        # Commit changes
+        self.nodes = new_nodes
+        self.edges = new_edges
+        self.symmetries = new_symmetries
+        self._name_to_node_cache = new_name_cache
+        self.rebuild_cache()
+    except Exception as e:
+        # Failures before the commit step leave the skeleton untouched
+        raise ValueError(f"Failed to remove nodes: {e}") from e
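
The copy-then-commit approach in the diff above can be illustrated with a much smaller model. `MiniSkeleton` is a hypothetical stand-in that stores node names and `(source, destination)` tuples; it is a sketch of the pattern, not the sleap-io class.

```python
class MiniSkeleton:
    """Hypothetical stand-in: node names plus (source, destination) edge tuples."""

    def __init__(self, nodes, edges):
        self.nodes = list(nodes)
        self.edges = list(edges)

    def remove_nodes(self, names):
        # Resolve and validate first: an unknown name fails before any mutation.
        missing = [name for name in names if name not in self.nodes]
        if missing:
            raise ValueError(f"Nodes not found: {missing}")

        # Build the new state on the side instead of editing self.* in place.
        to_remove = set(names)
        new_nodes = [n for n in self.nodes if n not in to_remove]
        new_edges = [
            (src, dst)
            for src, dst in self.edges
            if src not in to_remove and dst not in to_remove
        ]

        # Commit: plain attribute assignments are the only mutations.
        self.nodes = new_nodes
        self.edges = new_edges


skel = MiniSkeleton(["A", "B", "C"], [("A", "B"), ("B", "C")])
try:
    skel.remove_nodes(["A", "Z"])  # "Z" is unknown, so "A" must survive too
except ValueError:
    pass
nodes_after_failed_call = list(skel.nodes)
skel.remove_nodes(["B"])
```

In the real class the commit step may trigger attribute callbacks (for example a cache rebuild), so the guarantee only covers failures that occur before the assignments begin.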

Comment on lines 352 to 389
 def add_edge(
     self,
     src: NodeOrIndex | Edge | tuple[NodeOrIndex, NodeOrIndex],
     dst: NodeOrIndex | None = None,
 ):
     """Add an `Edge` to the skeleton.

     Args:
-        src: The source `Node` or name of the source node.
-        dst: The destination `Node` or name of the destination node.
+        src: The source node specified as a `Node`, name or index.
+        dst: The destination node specified as a `Node`, name or index.
     """
+    edge = None
+    if type(src) == tuple:
+        src, dst = src
+
+    if isinstance(src, NodeOrIndex):
+        if not isinstance(dst, NodeOrIndex):
+            raise ValueError("Destination node must be specified.")
+
+        src = self.require_node(src)
+        dst = self.require_node(dst)
+        edge = Edge(src, dst)
+
     if type(src) == Edge:
         edge = src
-        if edge not in self.edges:
-            self.edges.append(edge)
-        if edge.source not in self.nodes:
-            self.add_node(edge.source)
-        if edge.destination not in self.nodes:
-            self.add_node(edge.destination)
-        return
-
-    if type(src) == str or type(src) == Node:
-        try:
-            src = self.index(src)
-        except KeyError:
-            self.add_node(src)
-            src = self.index(src)
-
-    if type(dst) == str or type(dst) == Node:
-        try:
-            dst = self.index(dst)
-        except KeyError:
-            self.add_node(dst)
-            dst = self.index(dst)
-
-    edge = Edge(self.nodes[src], self.nodes[dst])

     if edge not in self.edges:
         self.edges.append(edge)

+def add_edges(self, edges: list[Edge | tuple[NodeOrIndex, NodeOrIndex]]):
+    """Add multiple `Edge`s to the skeleton.
+
+    Args:
+        edges: A list of `Edge` objects or 2-tuples of source and destination nodes.
+    """
+    for edge in edges:
+        self.add_edge(edge)


🛠️ Refactor suggestion

Consider adding cycle detection in edge addition.

While the edge management is well-implemented, adding cycle detection would prevent creation of cyclic dependencies in the skeleton graph.

 def add_edge(
     self,
     src: NodeOrIndex | Edge | tuple[NodeOrIndex, NodeOrIndex],
     dst: NodeOrIndex | None = None,
 ):
     edge = None
     if isinstance(src, tuple):
         src, dst = src

     if isinstance(src, NodeOrIndex):
         if not isinstance(dst, NodeOrIndex):
             raise ValueError("Destination node must be specified.")

         src = self.require_node(src)
         dst = self.require_node(dst)
+        # Check for cycles
+        if self._would_create_cycle(src, dst):
+            raise ValueError("Adding this edge would create a cycle in the skeleton graph")
         edge = Edge(src, dst)

     if isinstance(src, Edge):
         edge = src

     if edge not in self.edges:
         self.edges.append(edge)

+def _would_create_cycle(self, src: Node, dst: Node) -> bool:
+    """Check if adding an edge would create a cycle."""
+    visited = set()
+    def dfs(node: Node) -> bool:
+        if node == src:
+            return True
+        visited.add(node)
+        for edge in self.edges:
+            if edge.source == node and edge.destination not in visited:
+                if dfs(edge.destination):
+                    return True
+        return False
+    return dfs(dst)
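
The same reachability check can be tried outside the class. The sketch below is an iterative variant over plain `(source, destination)` tuples; `would_create_cycle` is a hypothetical free function written for illustration, and it relies on the fact that adding `src -> dst` closes a directed cycle exactly when `src` is already reachable from `dst`.

```python
def would_create_cycle(edges, src, dst):
    """Return True if adding the directed edge src -> dst would close a cycle.

    `edges` is an iterable of (source, destination) tuples. Hypothetical helper,
    not part of sleap-io.
    """
    # Adjacency list: node -> list of direct successors.
    children = {}
    for source, destination in edges:
        children.setdefault(source, []).append(destination)

    # Depth-first search from dst; reaching src means dst already leads to src.
    stack = [dst]
    visited = set()
    while stack:
        node = stack.pop()
        if node == src:
            return True
        if node in visited:
            continue
        visited.add(node)
        stack.extend(children.get(node, []))
    return False


edges = [("A", "B"), ("B", "C")]
closes_loop = would_create_cycle(edges, "C", "A")   # C -> A would complete A -> B -> C -> A
shortcut_ok = would_create_cycle(edges, "A", "C")   # A -> C only adds a shortcut
self_loop = would_create_cycle(edges, "A", "A")     # an edge from a node to itself
```

An explicit stack avoids Python's recursion limit on deep chains, and building the adjacency list once avoids rescanning every edge at each visited node.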
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def add_edge(
self,
src: NodeOrIndex | Edge | tuple[NodeOrIndex, NodeOrIndex],
dst: NodeOrIndex | None = None,
):
"""Add an `Edge` to the skeleton.
Args:
src: The source `Node` or name of the source node.
dst: The destination `Node` or name of the destination node.
src: The source node specified as a `Node`, name or index.
dst: The destination node specified as a `Node`, name or index.
"""
edge = None
if type(src) == tuple:
src, dst = src
if isinstance(src, NodeOrIndex):
if not isinstance(dst, NodeOrIndex):
raise ValueError("Destination node must be specified.")
src = self.require_node(src)
dst = self.require_node(dst)
edge = Edge(src, dst)
if type(src) == Edge:
edge = src
if edge not in self.edges:
self.edges.append(edge)
if edge.source not in self.nodes:
self.add_node(edge.source)
if edge.destination not in self.nodes:
self.add_node(edge.destination)
return
if type(src) == str or type(src) == Node:
try:
src = self.index(src)
except KeyError:
self.add_node(src)
src = self.index(src)
if type(dst) == str or type(dst) == Node:
try:
dst = self.index(dst)
except KeyError:
self.add_node(dst)
dst = self.index(dst)
edge = Edge(self.nodes[src], self.nodes[dst])
if edge not in self.edges:
self.edges.append(edge)
def add_edges(self, edges: list[Edge | tuple[NodeOrIndex, NodeOrIndex]]):
"""Add multiple `Edge`s to the skeleton.
Args:
edges: A list of `Edge` objects or 2-tuples of source and destination nodes.
"""
for edge in edges:
self.add_edge(edge)
def add_edge(
    self,
    src: NodeOrIndex | Edge | tuple[NodeOrIndex, NodeOrIndex],
    dst: NodeOrIndex | None = None,
):
    """Add an `Edge` to the skeleton.

    Args:
        src: The source node specified as a `Node`, name or index.
        dst: The destination node specified as a `Node`, name or index.
    """
    edge = None
    if isinstance(src, tuple):
        src, dst = src

    if isinstance(src, NodeOrIndex):
        if not isinstance(dst, NodeOrIndex):
            raise ValueError("Destination node must be specified.")

        src = self.require_node(src)
        dst = self.require_node(dst)

        # Check for cycles
        if self._would_create_cycle(src, dst):
            raise ValueError("Adding this edge would create a cycle in the skeleton graph")

        edge = Edge(src, dst)

    if isinstance(src, Edge):
        edge = src

    if edge not in self.edges:
        self.edges.append(edge)

def _would_create_cycle(self, src: Node, dst: Node) -> bool:
    """Check if adding an edge would create a cycle."""
    visited = set()

    def dfs(node: Node) -> bool:
        if node == src:
            return True
        visited.add(node)
        for edge in self.edges:
            if edge.source == node and edge.destination not in visited:
                if dfs(edge.destination):
                    return True
        return False

    return dfs(dst)

def add_edges(self, edges: list[Edge | tuple[NodeOrIndex, NodeOrIndex]]):
    """Add multiple `Edge`s to the skeleton.

    Args:
        edges: A list of `Edge` objects or 2-tuples of source and destination nodes.
    """
    for edge in edges:
        self.add_edge(edge)
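
The cycle check in the suggestion above can also be sketched standalone. This is a minimal illustration, assuming edges are plain `(source, destination)` tuples of node names rather than `Edge` objects; the function name `would_create_cycle` and the adjacency-list layout are hypothetical and not part of the library. It builds the adjacency map once instead of rescanning the full edge list at every visited node, and uses an explicit stack instead of recursion.

```python
from collections import defaultdict


def would_create_cycle(edges, src, dst):
    """Return True if adding the directed edge src -> dst would close a cycle.

    Hypothetical helper: `edges` is a list of (source, destination) tuples.
    """
    # Build the adjacency list once so outgoing edges are cheap to look up.
    children = defaultdict(list)
    for a, b in edges:
        children[a].append(b)

    # Depth-first search from dst: if src is reachable from dst,
    # then the new edge src -> dst would complete a loop.
    stack, visited = [dst], set()
    while stack:
        node = stack.pop()
        if node == src:
            return True
        if node in visited:
            continue
        visited.add(node)
        stack.extend(children[node])
    return False


edges = [("head", "thorax"), ("thorax", "abdomen")]
closes_loop = would_create_cycle(edges, "abdomen", "head")
stays_acyclic = would_create_cycle(edges, "head", "abdomen")
```

Here `closes_loop` is `True` because `abdomen` is already reachable from `head`, while `stays_acyclic` is `False`. A self-loop (`src == dst`) is also reported as a cycle.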
🧰 Tools
🪛 Ruff

364-364: Use is and is not for type comparisons, or isinstance() for isinstance checks (E721)

375-375: Use is and is not for type comparisons, or isinstance() for isinstance checks (E721)
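
As a small illustration of why E721 matters for the `type(src) == tuple` check flagged above: an exact type comparison rejects subclasses, while `isinstance()` accepts them. The `EdgeTuple` name below is a hypothetical tuple subclass used only for this sketch.

```python
from collections import namedtuple

# Hypothetical tuple subclass standing in for a (source, destination) pair.
EdgeTuple = namedtuple("EdgeTuple", ["src", "dst"])
pair = EdgeTuple("head", "thorax")

# Exact type comparison ignores subclasses, so this is False.
exact_match = type(pair) == tuple

# isinstance() honors inheritance, so this is True.
subclass_ok = isinstance(pair, tuple)
```

With the `type(...) ==` form, a named tuple passed as an edge would skip the unpacking branch; the `isinstance()` form handles it.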

Comment on lines +618 to +625
if node_map is None:
    node_map = {}
    for old_node in old_skeleton.nodes:
        for new_node in new_skeleton.nodes:
            if old_node.name == new_node.name:
                node_map[old_node] = new_node
                break
else:

🛠️ Refactor suggestion

Optimize node mapping creation performance.

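The current implementation uses nested loops to create the node mapping, which is O(n·m) in the number of old and new nodes (O(n²) when the skeletons are the same size). Building a name-to-node dictionary first gives O(1) lookups and O(n + m) overall: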

 if node_map is None:
     node_map = {}
+    new_node_map = {node.name: node for node in new_skeleton.nodes}
     for old_node in old_skeleton.nodes:
-        for new_node in new_skeleton.nodes:
-            if old_node.name == new_node.name:
-                node_map[old_node] = new_node
-                break
+        if old_node.name in new_node_map:
+            node_map[old_node] = new_node_map[old_node.name]

Suggested change
if node_map is None:
    node_map = {}
    for old_node in old_skeleton.nodes:
        for new_node in new_skeleton.nodes:
            if old_node.name == new_node.name:
                node_map[old_node] = new_node
                break
else:

if node_map is None:
    node_map = {}
    new_node_map = {node.name: node for node in new_skeleton.nodes}
    for old_node in old_skeleton.nodes:
        if old_node.name in new_node_map:
            node_map[old_node] = new_node_map[old_node.name]
else:
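
The dictionary-based matching can be tried in isolation. This is a minimal sketch, assuming a hypothetical frozen `Node` dataclass as a stand-in for the library's node type and a hypothetical helper name `match_nodes_by_name`. It uses `setdefault` so that, if two new nodes share a name, the first one wins, which matches the `break` in the nested-loop version; the plain dict comprehension in the suggestion would keep the last one instead.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    """Hypothetical stand-in for a skeleton node (hashable via frozen=True)."""

    name: str


def match_nodes_by_name(old_nodes, new_nodes):
    """Map each old node to the new node with the same name, if any."""
    # Index new nodes by name once; setdefault keeps the first node per name.
    new_by_name = {}
    for node in new_nodes:
        new_by_name.setdefault(node.name, node)

    # One dictionary lookup per old node instead of an inner loop.
    return {
        old: new_by_name[old.name]
        for old in old_nodes
        if old.name in new_by_name
    }


old_nodes = [Node("head"), Node("tail")]
new_nodes = [Node("head"), Node("thorax")]
node_map = match_nodes_by_name(old_nodes, new_nodes)
```

Only `head` appears in both lists, so `node_map` has a single entry for it and `tail` is left unmapped.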

@talmo talmo merged commit 780deb0 into main Nov 1, 2024
9 checks passed
@talmo talmo deleted the talmo/skeleton-mutation-qol branch November 1, 2024 11:21