-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add skeleton manipulation utilities #136
Conversation
WalkthroughThe changes introduce significant enhancements to the Changes
Possibly related PRs
Poem
📜 Recent review detailsConfiguration used: CodeRabbit UI 📒 Files selected for processing (1)
🧰 Additional context used🪛 Ruffsleap_io/model/skeleton.py280-280: Use (E721) 287-287: Use (E721) 289-289: Use (E721) 296-296: Use (E721) 298-298: Use (E721) 300-300: Use (E721) 318-318: Use (E721) 321-321: Use (E721) 361-361: Use (E721) 378-378: Use (E721) 389-389: Use (E721) 415-415: Use (E721) 460-460: Use (E721) 469-469: Use (E721) 471-471: Use (E721) 🔇 Additional comments (14)sleap_io/model/skeleton.py (14)
Defining
Using
The
Initializing
Calling
The
The return type hint
Consistently using
The 🧰 Tools🪛 Ruff287-287: Use (E721) 289-289: Use (E721) 296-296: Use (E721) 298-298: Use (E721) 300-300: Use (E721)
The 🧰 Tools🪛 Ruff361-361: Use (E721)
Accepting various input types for 🧰 Tools🪛 Ruff378-378: Use (E721) 389-389: Use (E721)
The 🧰 Tools🪛 Ruff460-460: Use (E721) 469-469: Use (E721) 471-471: Use (E721)
Removing nodes and associated edges and symmetries while updating caches maintains integrity.
Reordering nodes enhances control over node organization while keeping caches consistent. Thank you for using CodeRabbit. We offer it for free to the OSS community and would appreciate your support in helping us grow. If you find it useful, would you consider giving us a shout-out on your favorite social media? 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments. CodeRabbit Commands (Invoked using PR comments)
Other keywords and placeholders
CodeRabbit Configuration File (
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 3
🧹 Outside diff range and nitpick comments (1)
sleap_io/model/skeleton.py (1)
333-363
: Consider enhancing type hints and documentation.The documentation is comprehensive, but here are some suggestions for improvement:
- Simplify the type hint using the union operator (Python 3.10+):
-def rename_nodes(self, name_map: dict[str | int | Node, str] | list[str]): +def rename_nodes(self, name_map: dict[str | int | Node, str] | list[str]):
- Add a note about the behavior of node references in the docstring:
Notes: This method should always be used when renaming nodes in the skeleton as it handles updating the lookup caches necessary for indexing nodes by name. After renaming, instances using this skeleton do NOT need to be updated as the nodes are stored by reference in the skeleton. This means that any existing Edge or Symmetry objects will continue to work correctly with the renamed nodes.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (2)
sleap_io/model/skeleton.py
(1 hunks)tests/model/test_skeleton.py
(1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_io/model/skeleton.py
364-364: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
373-373: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
375-375: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
🔇 Additional comments (2)
tests/model/test_skeleton.py (1)
158-171
: LGTM! Well-structured test implementation.
The test function follows good testing practices with clear arrange-act-assert patterns and proper error case coverage.
sleap_io/model/skeleton.py (1)
333-386
: Verify the impact on existing instances.
The implementation looks correct, but let's verify that node references are preserved.
✅ Verification successful
Based on the search results, I can now generate the final response:
Node references are safely maintained across the codebase
The verification confirms that node name modifications are properly encapsulated:
- The only direct
node.name
assignment is within therename_nodes
method itself - All other occurrences are read-only assertions in test files or equality comparisons
- No other methods were found that modify node names
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
#!/bin/bash
# Description: Check if there are any direct node.name assignments elsewhere in the codebase
# that might bypass this method and break the node map.
# Search for direct name assignments
echo "Searching for direct node.name assignments:"
rg "\.name\s*=" --type py
# Search for other methods that might modify node names
echo -e "\nSearching for methods that might modify node names:"
ast-grep --pattern 'def $FUNC($ARGS) { $$$ $NODE.name $$$}'
Length of output: 1374
🧰 Tools
🪛 Ruff
364-364: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
373-373: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
375-375: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
tests/model/test_skeleton.py
Outdated
def test_rename_nodes(): | ||
"""Test renaming nodes in the skeleton.""" | ||
skel = Skeleton(["A", "B", "C"]) | ||
skel.rename_nodes({"A": "X", "B": "Y", "C": "Z"}) | ||
assert skel.node_names == ["X", "Y", "Z"] | ||
|
||
skel.rename_nodes(["a", "b", "c"]) | ||
assert skel.node_names == ["a", "b", "c"] | ||
|
||
with pytest.raises(ValueError): | ||
skel.rename_nodes({"a": "b"}) | ||
|
||
with pytest.raises(ValueError): | ||
skel.rename_nodes({"d": "e"}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Consider adding more test cases for comprehensive coverage.
While the current test cases cover the basic functionality, consider adding tests for:
- Verifying that edges and symmetries are preserved after renaming
- Testing partial renaming with dictionary (only renaming some nodes)
- Ensuring node indices remain consistent after renaming
Example test cases to add:
def test_rename_nodes_preserves_structure():
"""Test that renaming preserves edges and symmetries."""
# Test edges preservation
skel = Skeleton(["A", "B", "C"], edges=[("A", "B"), ("B", "C")])
original_edges = skel.edge_inds.copy()
skel.rename_nodes({"A": "X", "B": "Y", "C": "Z"})
assert skel.edge_inds == original_edges
assert skel.edge_names == [("X", "Y"), ("Y", "Z")]
# Test symmetries preservation
skel = Skeleton(["L", "R", "C"])
skel.add_symmetry("L", "R")
original_symmetries = skel.symmetry_inds.copy()
skel.rename_nodes(["Left", "Right", "Center"])
assert skel.symmetry_inds == original_symmetries
assert skel.symmetry_names == [("Left", "Right")]
def test_partial_rename():
"""Test renaming only some nodes."""
skel = Skeleton(["A", "B", "C"])
skel.rename_nodes({"A": "X", "C": "Z"}) # B remains unchanged
assert skel.node_names == ["X", "B", "Z"]
sleap_io/model/skeleton.py
Outdated
for old_name, new_name in name_map.items(): | ||
if type(old_name) == Node: | ||
old_name = old_name.name | ||
if type(new_name) == int: | ||
new_name = self.nodes[new_name].name | ||
|
||
if old_name not in self._node_name_map: | ||
raise ValueError(f"Node '{old_name}' not found in the skeleton.") | ||
if new_name in self._node_name_map: | ||
raise ValueError(f"Node '{new_name}' already exists in the skeleton.") | ||
|
||
node = self._node_name_map[old_name] | ||
node.name = new_name | ||
self._node_name_map[new_name] = node | ||
del self._node_name_map[old_name] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Improve type checking and make node map updates safer.
The node map updates could be made safer and more maintainable:
+ # Validate input types and names first
+ updates = {}
for old_name, new_name in name_map.items():
- if type(old_name) == Node:
+ if isinstance(old_name, Node):
old_name = old_name.name
- if type(new_name) == int:
+ if isinstance(new_name, int):
new_name = self.nodes[new_name].name
+ if not isinstance(new_name, str):
+ raise ValueError(f"New name must be a string, got {type(new_name)}")
if old_name not in self._node_name_map:
raise ValueError(f"Node '{old_name}' not found in the skeleton.")
if new_name in self._node_name_map:
raise ValueError(f"Node '{new_name}' already exists in the skeleton.")
+ updates[old_name] = (self._node_name_map[old_name], new_name)
+ # Apply updates after validation
+ for old_name, (node, new_name) in updates.items():
- node = self._node_name_map[old_name]
node.name = new_name
self._node_name_map[new_name] = node
del self._node_name_map[old_name]
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
for old_name, new_name in name_map.items(): | |
if type(old_name) == Node: | |
old_name = old_name.name | |
if type(new_name) == int: | |
new_name = self.nodes[new_name].name | |
if old_name not in self._node_name_map: | |
raise ValueError(f"Node '{old_name}' not found in the skeleton.") | |
if new_name in self._node_name_map: | |
raise ValueError(f"Node '{new_name}' already exists in the skeleton.") | |
node = self._node_name_map[old_name] | |
node.name = new_name | |
self._node_name_map[new_name] = node | |
del self._node_name_map[old_name] | |
# Validate input types and names first | |
updates = {} | |
for old_name, new_name in name_map.items(): | |
if isinstance(old_name, Node): | |
old_name = old_name.name | |
if isinstance(new_name, int): | |
new_name = self.nodes[new_name].name | |
if not isinstance(new_name, str): | |
raise ValueError(f"New name must be a string, got {type(new_name)}") | |
if old_name not in self._node_name_map: | |
raise ValueError(f"Node '{old_name}' not found in the skeleton.") | |
if new_name in self._node_name_map: | |
raise ValueError(f"Node '{new_name}' already exists in the skeleton.") | |
updates[old_name] = (self._node_name_map[old_name], new_name) | |
# Apply updates after validation | |
for old_name, (node, new_name) in updates.items(): | |
node.name = new_name | |
self._node_name_map[new_name] = node | |
del self._node_name_map[old_name] |
🧰 Tools
🪛 Ruff
373-373: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
375-375: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #136 +/- ##
==========================================
+ Coverage 96.13% 96.33% +0.19%
==========================================
Files 17 17
Lines 2124 2237 +113
==========================================
+ Hits 2042 2155 +113
Misses 82 82 ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 1
🧹 Outside diff range and nitpick comments (4)
tests/model/test_skeleton.py (2)
177-193
: LGTM, consider adding edge case tests.The test covers all node identification methods and basic error handling well. Consider adding tests for:
- Edge cases (renaming first/last node)
- Special characters in node names
- Verification that non-target nodes remain unchanged
156-158
: Maintain consistent spacing between test functions.To follow PEP 8 and maintain consistency with the rest of the file, use two blank lines between test functions.
def test_add_symmetry(): # ... existing code ... - - def test_rename_nodes(): # ... existing code ... - - def test_rename_node():Also applies to: 176-177
tests/model/test_labels.py (1)
638-667
: LGTM! Consider adding edge cases.The test implementation looks good and follows the established patterns. It effectively tests the basic functionality of the
instances
property.Consider adding these edge cases to make the test more comprehensive:
- Empty labeled frames (no instances)
- Predicted instances vs. regular instances
- Instance filtering (e.g., by track, skeleton)
def test_labels_instances(): labels = Labels() # Test empty labels assert len(list(labels.instances)) == 0 # Test empty labeled frame labels.append(LabeledFrame(video=Video("test.mp4"), frame_idx=0)) assert len(list(labels.instances)) == 0 # Test regular vs predicted instances labels.append( LabeledFrame( video=Video("test.mp4"), frame_idx=1, instances=[ Instance.from_numpy( np.array([[0, 1], [2, 3]]), skeleton=Skeleton(["A", "B"]) ) ], predicted_instances=[ PredictedInstance.from_numpy( np.array([[4, 5], [6, 7]]), skeleton=Skeleton(["A", "B"]) ) ] ) ) assert len(list(labels.instances)) == 1 # Should only count regular instances # Test instance filtering track = Track("test_track") labels.append( LabeledFrame( video=labels.video, frame_idx=2, instances=[ Instance.from_numpy( np.array([[0, 1], [2, 3]]), skeleton=labels.skeleton, track=track ) ] ) ) # Add assertions for filtering by track, skeleton, etc.sleap_io/model/labels.py (1)
462-465
: LGTM! Well-implemented property using memory-efficient generator expression.The implementation efficiently provides access to all instances across labeled frames using a generator expression, which is memory-efficient for large datasets as it doesn't materialize the entire list at once.
The use of a generator expression here is particularly beneficial when dealing with large datasets, as it allows for streaming access to instances without loading all of them into memory simultaneously. This approach supports efficient iteration and memory usage, especially important for processing large collections of labeled frames.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (4)
sleap_io/model/labels.py
(2 hunks)sleap_io/model/skeleton.py
(1 hunks)tests/model/test_labels.py
(1 hunks)tests/model/test_skeleton.py
(1 hunks)
🧰 Additional context used
🪛 Ruff
sleap_io/model/skeleton.py
364-364: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
373-373: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
375-375: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
🔇 Additional comments (3)
tests/model/test_skeleton.py (1)
158-175
: LGTM, but test coverage could be expanded.
The test cases cover the basic functionality and error handling well. However, as mentioned in the previous review, consider adding tests for:
- Edge preservation after renaming
- Symmetry preservation
- Partial renaming scenarios
- Node indices consistency
sleap_io/model/skeleton.py (1)
388-396
: LGTM!
The implementation is clean and effectively reuses the rename_nodes
method, maintaining consistent behavior across both methods.
sleap_io/model/labels.py (1)
24-24
: LGTM! Appropriate addition of Iterator type hint import.
The import of Iterator
from typing
is correctly added to support type hinting of the new instances
property.
sleap_io/model/skeleton.py
Outdated
def rename_nodes(self, name_map: dict[str | int | Node, str] | list[str]): | ||
"""Rename nodes in the skeleton. | ||
|
||
Args: | ||
name_map: A dictionary mapping old node names to new node names. Keys can be | ||
specified as `Node` objects, integer indices, or string names. Values | ||
must be specified as string names. | ||
|
||
If a list of strings is provided of the same length as the current | ||
nodes, the nodes will be renamed to the names in the list in order. | ||
|
||
Notes: | ||
This method should always be used when renaming nodes in the skeleton as it | ||
handles updating the lookup caches necessary for indexing nodes by name. | ||
|
||
After renaming, instances using this skeleton do NOT need to be updated as | ||
the nodes are stored by reference in the skeleton. | ||
|
||
Example: | ||
>>> skel = Skeleton(["A", "B", "C"], edges=[("A", "B"), ("B", "C")]) | ||
>>> skel.rename_nodes({"A": "X", "B": "Y", "C": "Z"}) | ||
>>> skel.node_names | ||
["X", "Y", "Z"] | ||
>>> skel.rename_nodes(["a", "b", "c"]) | ||
>>> skel.node_names | ||
["a", "b", "c"] | ||
|
||
Raises: | ||
ValueError: If the new node names exist in the skeleton or if the old node | ||
names are not found in the skeleton. | ||
""" | ||
if type(name_map) == list: | ||
if len(name_map) != len(self.nodes): | ||
raise ValueError( | ||
"List of new node names must be the same length as the current " | ||
"nodes." | ||
) | ||
name_map = {node: name for node, name in zip(self.nodes, name_map)} | ||
|
||
for old_name, new_name in name_map.items(): | ||
if type(old_name) == Node: | ||
old_name = old_name.name | ||
if type(old_name) == int: | ||
old_name = self.nodes[old_name].name | ||
|
||
if old_name not in self._node_name_map: | ||
raise ValueError(f"Node '{old_name}' not found in the skeleton.") | ||
if new_name in self._node_name_map: | ||
raise ValueError(f"Node '{new_name}' already exists in the skeleton.") | ||
|
||
node = self._node_name_map[old_name] | ||
node.name = new_name | ||
self._node_name_map[new_name] = node | ||
del self._node_name_map[old_name] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add validation for empty input and string type.
While the existing implementation handles most edge cases well, consider adding these additional validations:
def rename_nodes(self, name_map: dict[str | int | Node, str] | list[str]):
+ if not name_map:
+ raise ValueError("name_map cannot be empty")
+
if isinstance(name_map, list):
+ if not all(isinstance(name, str) for name in name_map):
+ raise ValueError("All new names must be strings")
if len(name_map) != len(self.nodes):
raise ValueError(
"List of new node names must be the same length as the current "
"nodes."
)
name_map = {node: name for node, name in zip(self.nodes, name_map)}
for old_name, new_name in name_map.items():
+ if not isinstance(new_name, str):
+ raise ValueError(f"New name must be a string, got {type(new_name)}")
if type(old_name) == Node:
old_name = old_name.name
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def rename_nodes(self, name_map: dict[str | int | Node, str] | list[str]): | |
"""Rename nodes in the skeleton. | |
Args: | |
name_map: A dictionary mapping old node names to new node names. Keys can be | |
specified as `Node` objects, integer indices, or string names. Values | |
must be specified as string names. | |
If a list of strings is provided of the same length as the current | |
nodes, the nodes will be renamed to the names in the list in order. | |
Notes: | |
This method should always be used when renaming nodes in the skeleton as it | |
handles updating the lookup caches necessary for indexing nodes by name. | |
After renaming, instances using this skeleton do NOT need to be updated as | |
the nodes are stored by reference in the skeleton. | |
Example: | |
>>> skel = Skeleton(["A", "B", "C"], edges=[("A", "B"), ("B", "C")]) | |
>>> skel.rename_nodes({"A": "X", "B": "Y", "C": "Z"}) | |
>>> skel.node_names | |
["X", "Y", "Z"] | |
>>> skel.rename_nodes(["a", "b", "c"]) | |
>>> skel.node_names | |
["a", "b", "c"] | |
Raises: | |
ValueError: If the new node names exist in the skeleton or if the old node | |
names are not found in the skeleton. | |
""" | |
if type(name_map) == list: | |
if len(name_map) != len(self.nodes): | |
raise ValueError( | |
"List of new node names must be the same length as the current " | |
"nodes." | |
) | |
name_map = {node: name for node, name in zip(self.nodes, name_map)} | |
for old_name, new_name in name_map.items(): | |
if type(old_name) == Node: | |
old_name = old_name.name | |
if type(old_name) == int: | |
old_name = self.nodes[old_name].name | |
if old_name not in self._node_name_map: | |
raise ValueError(f"Node '{old_name}' not found in the skeleton.") | |
if new_name in self._node_name_map: | |
raise ValueError(f"Node '{new_name}' already exists in the skeleton.") | |
node = self._node_name_map[old_name] | |
node.name = new_name | |
self._node_name_map[new_name] = node | |
del self._node_name_map[old_name] | |
def rename_nodes(self, name_map: dict[str | int | Node, str] | list[str]): | |
"""Rename nodes in the skeleton. | |
Args: | |
name_map: A dictionary mapping old node names to new node names. Keys can be | |
specified as `Node` objects, integer indices, or string names. Values | |
must be specified as string names. | |
If a list of strings is provided of the same length as the current | |
nodes, the nodes will be renamed to the names in the list in order. | |
Notes: | |
This method should always be used when renaming nodes in the skeleton as it | |
handles updating the lookup caches necessary for indexing nodes by name. | |
After renaming, instances using this skeleton do NOT need to be updated as | |
the nodes are stored by reference in the skeleton. | |
Example: | |
>>> skel = Skeleton(["A", "B", "C"], edges=[("A", "B"), ("B", "C")]) | |
>>> skel.rename_nodes({"A": "X", "B": "Y", "C": "Z"}) | |
>>> skel.node_names | |
["X", "Y", "Z"] | |
>>> skel.rename_nodes(["a", "b", "c"]) | |
>>> skel.node_names | |
["a", "b", "c"] | |
Raises: | |
ValueError: If the new node names exist in the skeleton or if the old node | |
names are not found in the skeleton. | |
""" | |
if not name_map: | |
raise ValueError("name_map cannot be empty") | |
if isinstance(name_map, list): | |
if not all(isinstance(name, str) for name in name_map): | |
raise ValueError("All new names must be strings") | |
if len(name_map) != len(self.nodes): | |
raise ValueError( | |
"List of new node names must be the same length as the current " | |
"nodes." | |
) | |
name_map = {node: name for node, name in zip(self.nodes, name_map)} | |
for old_name, new_name in name_map.items(): | |
if not isinstance(new_name, str): | |
raise ValueError(f"New name must be a string, got {type(new_name)}") | |
if type(old_name) == Node: | |
old_name = old_name.name | |
if type(old_name) == int: | |
old_name = self.nodes[old_name].name | |
if old_name not in self._node_name_map: | |
raise ValueError(f"Node '{old_name}' not found in the skeleton.") | |
if new_name in self._node_name_map: | |
raise ValueError(f"Node '{new_name}' already exists in the skeleton.") | |
node = self._node_name_map[old_name] | |
node.name = new_name | |
self._node_name_map[new_name] = node | |
del self._node_name_map[old_name] |
🧰 Tools
🪛 Ruff
364-364: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
373-373: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
375-375: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
…and add bulk node addition
…edge addition method
…r node removal and reordering
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 5
🧹 Outside diff range and nitpick comments (5)
tests/model/test_instance.py (3)
180-190
: Consider enhancing error message verification.While the test effectively verifies the behavior during node removal, it could be more robust by checking the specific error message.
Consider updating the error check to verify the specific message:
- with pytest.raises(KeyError): + with pytest.raises(KeyError, match="Node 'Y' not found in skeleton"): inst.numpy() # .numpy() breaks without update
195-197
: Improve assertion readability.The multi-line assertion could be more readable using a separate variable.
Consider refactoring to:
- assert ( - list(inst.points.keys()) != skel.nodes - ) # but the points dict still has the old order + points_keys = list(inst.points.keys()) + assert points_keys != skel.nodes # but the points dict still has the old order
169-169
: Add docstring to test function.The test function is missing a docstring describing its purpose and the scenarios being tested.
Add a descriptive docstring:
def test_instance_update_skeleton(): + """Test Instance behavior when Skeleton is modified. + + Verifies: + 1. Node renaming doesn't require instance update + 2. Node removal requires update_skeleton() call + 3. Node reordering affects numpy() output and requires update + """tests/model/test_skeleton.py (1)
226-245
: Add test case for symmetry preservation during reordering.While the test cases cover node reordering and its impact on edges, consider adding a test case to verify that symmetry relationships are preserved after reordering nodes.
Example test case:
def test_reorder_nodes_preserves_symmetry(): skel = Skeleton(["A", "BL", "BR", "C"]) skel.add_symmetry("BL", "BR") original_symmetries = skel.symmetry_inds.copy() skel.reorder_nodes(["C", "BR", "A", "BL"]) assert skel.symmetry_inds != original_symmetries # indices should update assert skel.symmetry_names == [("BL", "BR")] # but relationships should remainsleap_io/model/instance.py (1)
311-328
: Enhance method documentation with type hints and examples.The documentation is clear but could be more comprehensive. Consider:
- Adding return type hint (-> None)
- Documenting any exceptions that might be raised
- Adding a usage example to illustrate the behavior
- def update_skeleton(self): + def update_skeleton(self) -> None: """Update the points dictionary to match the skeleton. Points associated with nodes that are no longer in the skeleton will be removed. Additionally, the keys of the points dictionary will be ordered to match the order of the nodes in the skeleton. Notes: This method is useful when the skeleton has been updated (e.g., nodes removed or reordered). However, it is recommended to use `Labels`-level methods (e.g., `Labels.remove_nodes()`) when manipulating the skeleton as these will automatically call this method on every instance. It is NOT necessary to call this method after renaming nodes. + + Example: + >>> instance = Instance(points=points, skeleton=skeleton) + >>> skeleton.remove_nodes(['node_to_remove']) + >>> instance.update_skeleton() # Points for removed node are dropped + + Returns: + None + + Raises: + AttributeError: If skeleton is not set """
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (4)
sleap_io/model/instance.py
(1 hunks)sleap_io/model/skeleton.py
(6 hunks)tests/model/test_instance.py
(1 hunks)tests/model/test_skeleton.py
(3 hunks)
🧰 Additional context used
🪛 Ruff
sleap_io/model/skeleton.py
264-264: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
271-271: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
273-273: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
280-280: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
282-282: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
284-284: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
302-302: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
305-305: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
345-345: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
362-362: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
373-373: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
399-399: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
444-444: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
453-453: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
455-455: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
🔇 Additional comments (11)
tests/model/test_instance.py (2)
169-179
: LGTM! Well-structured test setup and rename verification.
The test setup and node renaming verification is clear, comprehensive, and follows testing best practices by checking both direct node access and array representation.
169-199
: 🛠️ Refactor suggestion
Consider adding tests for edge cases.
The test provides good coverage of basic operations, but could be enhanced with additional scenarios:
- Multiple simultaneous operations (e.g., rename + remove)
- Edge case where all nodes are removed
- Invalid node names in operations
Let's check if these cases are covered elsewhere:
tests/model/test_skeleton.py (5)
107-108
: LGTM! Enhanced test coverage for node addition.
The additional assertions properly verify node presence, indexing, and multiple node addition scenarios.
Also applies to: 110-115, 122-124
145-147
: LGTM! Good coverage of multiple edge addition.
The test case properly verifies the add_edges method with multiple edges and confirms correct edge indices.
169-193
: Add tests for edge and symmetry preservation after renaming.
While the current test cases cover basic renaming functionality and error cases, they don't verify that edges and symmetries are preserved after renaming operations.
194-224
: LGTM! Thorough testing of node removal functionality.
The test cases comprehensively verify:
- Node index updates after removal
- Edge list maintenance
- Symmetry relationship updates
- Both single and multiple node removal scenarios
- Error handling for invalid cases
169-245
: LGTM! Comprehensive test suite for skeleton manipulation utilities.
The test suite provides good coverage of the new functionality with:
- Clear test organization
- Comprehensive error case handling
- Verification of data structure integrity
- Consistent testing style
Minor suggestions for additional test cases have been provided above, but the current coverage is already robust.
sleap_io/model/instance.py (1)
311-343
: Verify integration with Labels-level methods.
The documentation mentions that Labels-level methods should call this method automatically. Let's verify this integration to ensure proper skeleton synchronization.
sleap_io/model/skeleton.py (3)
Line range hint 74-110
: LGTM! Type improvements enhance code safety and readability.
The introduction of NodeOrIndex
type alias and the renaming of cache fields to more descriptive names (_name_to_node_cache
and _node_to_ind_cache
) improve code clarity. The callback method _nodes_on_setattr
ensures cache consistency.
412-467
: LGTM! Well-implemented node renaming with proper validation.
The node renaming implementation is thorough with:
- Comprehensive input validation
- Clear error messages
- Proper cache updates
- Excellent documentation including examples
🧰 Tools
🪛 Ruff
444-444: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
453-453: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
455-455: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
552-589
: LGTM! Well-implemented node reordering with clear documentation.
The node reordering implementation is solid with:
- Proper validation of the new order
- Clear warnings about instance updates
- Comprehensive documentation about potential side effects
# Create a new dictionary to hold the updated points | ||
new_points = {} | ||
|
||
# Iterate over the nodes in the skeleton | ||
for node in self.skeleton.nodes: | ||
# Get the point associated with the node | ||
point = self.points.get(node, None) | ||
|
||
# If the point is not None, add it to the new dictionary | ||
if point is not None: | ||
new_points[node] = point | ||
|
||
# Update the points dictionary | ||
self.points = new_points |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Consider optimizing for the no-change case.
The current implementation always creates a new dictionary, even when no changes are needed. Consider adding a fast path for when the points dictionary already matches the skeleton.
def update_skeleton(self) -> None:
+ # Fast path: check if update is needed
+ if (set(self.points.keys()) == set(self.skeleton.nodes) and
+ list(self.points.keys()) == list(self.skeleton.nodes)):
+ return
+
# Create a new dictionary to hold the updated points
new_points = {}
# Iterate over the nodes in the skeleton
for node in self.skeleton.nodes:
# Get the point associated with the node
point = self.points.get(node, None)
# If the point is not None, add it to the new dictionary
if point is not None:
new_points[node] = point
# Update the points dictionary
self.points = new_points
Committable suggestion skipped: line range outside the PR's diff.
def rebuild_cache(self, nodes: list[Node] | None = None): | ||
"""Rebuild the node name/index to `Node` map caches. | ||
|
||
Args: | ||
nodes: A list of `Node` objects to update the cache with. If not provided, | ||
the cache will be updated with the current nodes in the skeleton. If | ||
nodes are provided, the cache will be updated with the provided nodes, | ||
but the current nodes in the skeleton will not be updated. Default is | ||
`None`. | ||
|
||
Notes: | ||
This function should be called when nodes or node list is mutated to update | ||
the lookup caches for indexing nodes by name or `Node` object. | ||
|
||
This is done automatically when nodes are added or removed from the skeleton | ||
using the convenience methods in this class. | ||
|
||
This method only needs to be used when manually mutating nodes or the node | ||
list directly. | ||
""" | ||
if nodes is None: | ||
nodes = self.nodes | ||
self._name_to_node_cache = {node.name: node for node in nodes} | ||
self._node_to_ind_cache = {node: i for i, node in enumerate(nodes)} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Consider adding error handling for duplicate node names.
While the cache rebuilding logic is solid, it might silently overwrite entries if there are duplicate node names.
def rebuild_cache(self, nodes: list[Node] | None = None):
if nodes is None:
nodes = self.nodes
+ # Check for duplicate names
+ name_counts = {}
+ for node in nodes:
+ name_counts[node.name] = name_counts.get(node.name, 0) + 1
+ if name_counts[node.name] > 1:
+ raise ValueError(f"Duplicate node name found: {node.name}")
self._name_to_node_cache = {node.name: node for node in nodes}
self._node_to_ind_cache = {node: i for i, node in enumerate(nodes)}
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def rebuild_cache(self, nodes: list[Node] | None = None): | |
"""Rebuild the node name/index to `Node` map caches. | |
Args: | |
nodes: A list of `Node` objects to update the cache with. If not provided, | |
the cache will be updated with the current nodes in the skeleton. If | |
nodes are provided, the cache will be updated with the provided nodes, | |
but the current nodes in the skeleton will not be updated. Default is | |
`None`. | |
Notes: | |
This function should be called when nodes or node list is mutated to update | |
the lookup caches for indexing nodes by name or `Node` object. | |
This is done automatically when nodes are added or removed from the skeleton | |
using the convenience methods in this class. | |
This method only needs to be used when manually mutating nodes or the node | |
list directly. | |
""" | |
if nodes is None: | |
nodes = self.nodes | |
self._name_to_node_cache = {node.name: node for node in nodes} | |
self._node_to_ind_cache = {node: i for i, node in enumerate(nodes)} | |
def rebuild_cache(self, nodes: list[Node] | None = None): | |
"""Rebuild the node name/index to `Node` map caches. | |
Args: | |
nodes: A list of `Node` objects to update the cache with. If not provided, | |
the cache will be updated with the current nodes in the skeleton. If | |
nodes are provided, the cache will be updated with the provided nodes, | |
but the current nodes in the skeleton will not be updated. Default is | |
`None`. | |
Notes: | |
This function should be called when nodes or node list is mutated to update | |
the lookup caches for indexing nodes by name or `Node` object. | |
This is done automatically when nodes are added or removed from the skeleton | |
using the convenience methods in this class. | |
This method only needs to be used when manually mutating nodes or the node | |
list directly. | |
""" | |
if nodes is None: | |
nodes = self.nodes | |
# Check for duplicate names | |
name_counts = {} | |
for node in nodes: | |
name_counts[node.name] = name_counts.get(node.name, 0) + 1 | |
if name_counts[node.name] > 1: | |
raise ValueError(f"Duplicate node name found: {node.name}") | |
self._name_to_node_cache = {node.name: node for node in nodes} | |
self._node_to_ind_cache = {node: i for i, node in enumerate(nodes)} |
sleap_io/model/skeleton.py
Outdated
def remove_nodes(self, nodes: list[NodeOrIndex]): | ||
"""Remove nodes from the skeleton. | ||
|
||
Args: | ||
nodes: A list of node names, indices, or `Node` objects to remove. | ||
|
||
Notes: | ||
This method should always be used when removing nodes from the skeleton as | ||
it handles updating the lookup caches necessary for indexing nodes by name. | ||
|
||
Any edges and symmetries that are connected to the removed nodes will also | ||
be removed. | ||
|
||
Warning: | ||
**This method does NOT update instances** that use this skeleton to reflect | ||
changes. | ||
|
||
It is recommended to use the `Labels.remove_nodes()` method which will | ||
update all contained to reflect the changes made to the skeleton. | ||
|
||
To manually update instances after this method is called, call | ||
`instance.update_nodes()` on each instance that uses this skeleton. | ||
""" | ||
# Standardize input and make a pre-mutation copy before keys are changed. | ||
rm_node_objs = [self.require_node(node, add_missing=False) for node in nodes] | ||
|
||
# Remove nodes from the skeleton. | ||
for node in rm_node_objs: | ||
self.nodes.remove(node) | ||
del self._name_to_node_cache[node.name] | ||
|
||
# Remove edges connected to the removed nodes. | ||
self.edges = [ | ||
edge | ||
for edge in self.edges | ||
if edge.source not in rm_node_objs and edge.destination not in rm_node_objs | ||
] | ||
|
||
# Remove symmetries connected to the removed nodes. | ||
self.symmetries = [ | ||
symmetry | ||
for symmetry in self.symmetries | ||
if symmetry.nodes.isdisjoint(rm_node_objs) | ||
] | ||
|
||
# Update node index map. | ||
self.rebuild_cache() | ||
|
||
def remove_node(self, node: NodeOrIndex): | ||
"""Remove a single node from the skeleton. | ||
|
||
Args: | ||
node: The node to remove. Can be specified as a string name, integer index, | ||
or `Node` object. | ||
|
||
Notes: | ||
This method should always be used when removing nodes from the skeleton as | ||
it handles updating the lookup caches necessary for indexing nodes by name. | ||
|
||
Any edges and symmetries that are connected to the removed node will also be | ||
removed. | ||
|
||
Warning: | ||
**This method does NOT update instances** that use this skeleton to reflect | ||
changes. | ||
|
||
It is recommended to use the `Labels.remove_nodes()` method which will | ||
update all contained instances to reflect the changes made to the skeleton. | ||
|
||
To manually update instances after this method is called, call | ||
`Instance.update_skeleton()` on each instance that uses this skeleton. | ||
""" | ||
self.remove_nodes([node]) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Consider making node removal atomic.
While the node removal implementation is thorough, it could fail mid-operation leaving the skeleton in an inconsistent state.
def remove_nodes(self, nodes: list[NodeOrIndex]):
# Standardize input and make a pre-mutation copy before keys are changed.
rm_node_objs = [self.require_node(node, add_missing=False) for node in nodes]
+
+ # Create copies of current state
+ new_nodes = self.nodes.copy()
+ new_edges = self.edges.copy()
+ new_symmetries = self.symmetries.copy()
+ new_name_cache = self._name_to_node_cache.copy()
+
+ try:
+ # Remove nodes
+ for node in rm_node_objs:
+ new_nodes.remove(node)
+ del new_name_cache[node.name]
+
+ # Update edges and symmetries
+ new_edges = [
+ edge for edge in new_edges
+ if edge.source not in rm_node_objs and edge.destination not in rm_node_objs
+ ]
+ new_symmetries = [
+ symmetry for symmetry in new_symmetries
+ if symmetry.nodes.isdisjoint(rm_node_objs)
+ ]
+
+ # Commit changes
+ self.nodes = new_nodes
+ self.edges = new_edges
+ self.symmetries = new_symmetries
+ self._name_to_node_cache = new_name_cache
+ self.rebuild_cache()
+ except Exception as e:
+ # Rollback on error
+ raise ValueError(f"Failed to remove nodes: {str(e)}")
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def remove_nodes(self, nodes: list[NodeOrIndex]): | |
"""Remove nodes from the skeleton. | |
Args: | |
nodes: A list of node names, indices, or `Node` objects to remove. | |
Notes: | |
This method should always be used when removing nodes from the skeleton as | |
it handles updating the lookup caches necessary for indexing nodes by name. | |
Any edges and symmetries that are connected to the removed nodes will also | |
be removed. | |
Warning: | |
**This method does NOT update instances** that use this skeleton to reflect | |
changes. | |
It is recommended to use the `Labels.remove_nodes()` method which will | |
update all contained to reflect the changes made to the skeleton. | |
To manually update instances after this method is called, call | |
`instance.update_nodes()` on each instance that uses this skeleton. | |
""" | |
# Standardize input and make a pre-mutation copy before keys are changed. | |
rm_node_objs = [self.require_node(node, add_missing=False) for node in nodes] | |
# Remove nodes from the skeleton. | |
for node in rm_node_objs: | |
self.nodes.remove(node) | |
del self._name_to_node_cache[node.name] | |
# Remove edges connected to the removed nodes. | |
self.edges = [ | |
edge | |
for edge in self.edges | |
if edge.source not in rm_node_objs and edge.destination not in rm_node_objs | |
] | |
# Remove symmetries connected to the removed nodes. | |
self.symmetries = [ | |
symmetry | |
for symmetry in self.symmetries | |
if symmetry.nodes.isdisjoint(rm_node_objs) | |
] | |
# Update node index map. | |
self.rebuild_cache() | |
def remove_node(self, node: NodeOrIndex): | |
"""Remove a single node from the skeleton. | |
Args: | |
node: The node to remove. Can be specified as a string name, integer index, | |
or `Node` object. | |
Notes: | |
This method should always be used when removing nodes from the skeleton as | |
it handles updating the lookup caches necessary for indexing nodes by name. | |
Any edges and symmetries that are connected to the removed node will also be | |
removed. | |
Warning: | |
**This method does NOT update instances** that use this skeleton to reflect | |
changes. | |
It is recommended to use the `Labels.remove_nodes()` method which will | |
update all contained instances to reflect the changes made to the skeleton. | |
To manually update instances after this method is called, call | |
`Instance.update_skeleton()` on each instance that uses this skeleton. | |
""" | |
self.remove_nodes([node]) | |
def remove_nodes(self, nodes: list[NodeOrIndex]): | |
"""Remove nodes from the skeleton. | |
Args: | |
nodes: A list of node names, indices, or `Node` objects to remove. | |
Notes: | |
This method should always be used when removing nodes from the skeleton as | |
it handles updating the lookup caches necessary for indexing nodes by name. | |
Any edges and symmetries that are connected to the removed nodes will also | |
be removed. | |
Warning: | |
**This method does NOT update instances** that use this skeleton to reflect | |
changes. | |
It is recommended to use the `Labels.remove_nodes()` method which will | |
update all contained to reflect the changes made to the skeleton. | |
To manually update instances after this method is called, call | |
`instance.update_nodes()` on each instance that uses this skeleton. | |
""" | |
# Standardize input and make a pre-mutation copy before keys are changed. | |
rm_node_objs = [self.require_node(node, add_missing=False) for node in nodes] | |
# Create copies of current state | |
new_nodes = self.nodes.copy() | |
new_edges = self.edges.copy() | |
new_symmetries = self.symmetries.copy() | |
new_name_cache = self._name_to_node_cache.copy() | |
try: | |
# Remove nodes | |
for node in rm_node_objs: | |
new_nodes.remove(node) | |
del new_name_cache[node.name] | |
# Update edges and symmetries | |
new_edges = [ | |
edge for edge in new_edges | |
if edge.source not in rm_node_objs and edge.destination not in rm_node_objs | |
] | |
new_symmetries = [ | |
symmetry for symmetry in new_symmetries | |
if symmetry.nodes.isdisjoint(rm_node_objs) | |
] | |
# Commit changes | |
self.nodes = new_nodes | |
self.edges = new_edges | |
self.symmetries = new_symmetries | |
self._name_to_node_cache = new_name_cache | |
self.rebuild_cache() | |
except Exception as e: | |
# Rollback on error | |
raise ValueError(f"Failed to remove nodes: {str(e)}") | |
def remove_node(self, node: NodeOrIndex): | |
"""Remove a single node from the skeleton. | |
Args: | |
node: The node to remove. Can be specified as a string name, integer index, | |
or `Node` object. | |
Notes: | |
This method should always be used when removing nodes from the skeleton as | |
it handles updating the lookup caches necessary for indexing nodes by name. | |
Any edges and symmetries that are connected to the removed node will also be | |
removed. | |
Warning: | |
**This method does NOT update instances** that use this skeleton to reflect | |
changes. | |
It is recommended to use the `Labels.remove_nodes()` method which will | |
update all contained instances to reflect the changes made to the skeleton. | |
To manually update instances after this method is called, call | |
`Instance.update_skeleton()` on each instance that uses this skeleton. | |
""" | |
self.remove_nodes([node]) |
def __getitem__(self, idx: NodeOrIndex) -> Node: | ||
"""Return a `Node` when indexing by name or integer.""" | ||
if type(idx) == int: | ||
return self.nodes[idx] | ||
elif type(idx) == str: | ||
return self._node_name_map[idx] | ||
return self._name_to_node_cache[idx] | ||
else: | ||
raise IndexError(f"Invalid indexing argument for skeleton: {idx}") | ||
|
||
def __contains__(self, node: NodeOrIndex) -> bool: | ||
"""Check if a node is in the skeleton.""" | ||
if type(node) == str: | ||
return node in self._name_to_node_cache | ||
elif type(node) == Node: | ||
return node in self.nodes | ||
elif type(node) == int: | ||
return 0 <= node < len(self.nodes) | ||
else: | ||
raise ValueError(f"Invalid node type for skeleton: {node}") | ||
|
||
def add_node(self, node: Node | str): | ||
"""Add a `Node` to the skeleton. | ||
|
||
Args: | ||
node: A `Node` object or a string name to create a new node. | ||
|
||
Raises: | ||
ValueError: If the node already exists in the skeleton or if the node is | ||
not specified as a `Node` or string. | ||
""" | ||
node_name = node.name if type(node) == Node else node | ||
if node_name in self._node_name_map: | ||
raise ValueError(f"Node '{node_name}' already exists in the skeleton.") | ||
if node in self: | ||
raise ValueError(f"Node '{node}' already exists in the skeleton.") | ||
|
||
if type(node) == str: | ||
node = Node(node) | ||
if node not in self.nodes: | ||
self.nodes.append(node) | ||
self._update_node_map(None, self.nodes) | ||
|
||
def add_edge(self, src: Edge | Node | str = None, dst: Node | str = None): | ||
if type(node) != Node: | ||
raise ValueError(f"Invalid node type: {node} ({type(node)})") | ||
|
||
self.nodes.append(node) | ||
|
||
# Atomic update of the cache. | ||
self._name_to_node_cache[node.name] = node | ||
self._node_to_ind_cache[node] = len(self.nodes) - 1 | ||
|
||
def add_nodes(self, nodes: list[Node | str]): | ||
"""Add multiple `Node`s to the skeleton. | ||
|
||
Args: | ||
nodes: A list of `Node` objects or string names to create new nodes. | ||
""" | ||
for node in nodes: | ||
self.add_node(node) | ||
|
||
def require_node(self, node: NodeOrIndex, add_missing: bool = True) -> Node: | ||
"""Return a `Node` object, handling indexing and adding missing nodes. | ||
|
||
Args: | ||
node: A `Node` object, name or index. | ||
add_missing: If `True`, missing nodes will be added to the skeleton. If | ||
`False`, an error will be raised if the node is not found. Default is | ||
`True`. | ||
|
||
Returns: | ||
The `Node` object. | ||
|
||
Raises: | ||
IndexError: If the node is not found in the skeleton and `add_missing` is | ||
`False`. | ||
""" | ||
if node not in self: | ||
if add_missing: | ||
self.add_node(node) | ||
else: | ||
raise IndexError(f"Node '{node}' not found in the skeleton.") | ||
|
||
if type(node) == Node: | ||
return node | ||
|
||
return self[node] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Replace type() comparisons with isinstance() checks.
Following Python best practices, use isinstance()
for type checking as it handles inheritance properly.
- if type(idx) == int:
+ if isinstance(idx, int):
return self.nodes[idx]
- elif type(idx) == str:
+ elif isinstance(idx, str):
return self._name_to_node_cache[idx]
- if type(node) == str:
+ if isinstance(node, str):
return node in self._name_to_node_cache
- elif type(node) == Node:
+ elif isinstance(node, Node):
return node in self.nodes
- elif type(node) == int:
+ elif isinstance(node, int):
return 0 <= node < len(self.nodes)
Committable suggestion skipped: line range outside the PR's diff.
🧰 Tools
🪛 Ruff
271-271: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
273-273: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
280-280: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
282-282: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
284-284: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
302-302: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
305-305: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
345-345: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 6
🧹 Outside diff range and nitpick comments (7)
tests/model/test_instance.py (2)
169-200
: LGTM! Well-structured test cases for skeleton updates.The test cases effectively cover the three main skeleton manipulation scenarios (renaming, removal, reordering) with good assertions for both success and error cases.
Consider enhancing the reordering test with explicit assertions for the numpy array values before and after update to make the expected behavior more clear. For example:
# Reorder nodes old_order = inst.numpy() # capture original order skel.reorder_nodes(["Z", "X"]) assert_equal(inst.numpy(), [[2, 2], [0, 0]]) # verify immediate reorder effect inst.update_skeleton() assert_equal(inst.numpy(), old_order) # verify order is preserved after update
202-232
: LGTM! Comprehensive test coverage for skeleton replacement scenarios.The test cases effectively cover full replacement, partial replacement, and the fast path optimization with good assertions for data integrity and cleanup.
Consider adding a performance comparison for the fast path test to validate the optimization:
# Fast path performance comparison import time # Standard path timing start = time.perf_counter() inst.replace_skeleton(new_skel, node_map={"A": "X", "B": "Y", "C": "Z"}) standard_time = time.perf_counter() - start # Fast path timing start = time.perf_counter() inst.replace_skeleton(new_skel, rev_node_map=rev_node_map) fast_time = time.perf_counter() - start # Verify performance improvement assert fast_time < standard_time, "Fast path should be faster than standard path"sleap_io/model/instance.py (1)
343-394
: Consider enhancing documentation and input validation.The implementation is solid, but could benefit from:
- Input validation for
new_skeleton
to ensure it's not None- Adding an Examples section to the docstring showing common usage patterns
Add type validation and enhance docstring:
def replace_skeleton( self, new_skeleton: Skeleton, node_map: dict[NodeOrIndex, NodeOrIndex] | None = None, rev_node_map: dict[NodeOrIndex, NodeOrIndex] | None = None, ): """Replace the skeleton associated with the instance. The points dictionary will be updated to match the new skeleton. Args: new_skeleton: The new `Skeleton` to associate with the instance. node_map: Dictionary mapping nodes in the old skeleton to nodes in the new skeleton. Keys and values can be specified as `Node` objects, integer indices, or string names. If not provided, only nodes with identical names will be mapped. Points associated with unmapped nodes will be removed. rev_node_map: Dictionary mapping nodes in the new skeleton to nodes in the old skeleton. This is used internally when calling from `Labels.replace_skeleton()` as it is more efficient to compute this mapping once and pass it to all instances. No validation is done on this mapping, so nodes are expected to be `Node` objects. + + Examples: + >>> # Replace with new skeleton, automatically mapping nodes by name + >>> instance.replace_skeleton(new_skeleton) + >>> + >>> # Replace with explicit node mapping + >>> instance.replace_skeleton( + ... new_skeleton, + ... node_map={"head": "nose", "tail": "tail_tip"} + ... ) """ + if new_skeleton is None: + raise ValueError("new_skeleton cannot be None") + if rev_node_map is None:tests/model/test_labels.py (2)
670-682
: Consider adding edge cases to the node renaming tests.While the test covers the basic functionality well, consider adding tests for:
- Attempting to rename non-existent nodes
- Renaming to already existing node names
- Empty or invalid node names
703-719
: Consider adding validation tests for node reordering.While the test covers the basic functionality well, consider adding tests for:
- Reordering with missing nodes
- Reordering with duplicate nodes
- Reordering with invalid node names
sleap_io/model/labels.py (2)
468-585
: LGTM: Well-implemented node manipulation methods with thorough documentation.The methods properly handle both skeleton and instance updates, with comprehensive error handling. However, there's an opportunity to reduce code duplication.
Consider extracting the skeleton validation logic into a private helper method to reduce duplication. Here's a suggested implementation:
+ def _validate_and_get_skeleton(self, skeleton: Skeleton | None = None) -> Skeleton: + """Validate and return the skeleton to operate on. + + Args: + skeleton: Optional skeleton to validate. If None, uses the default skeleton. + + Returns: + The validated skeleton to use. + + Raises: + ValueError: If there is more than one skeleton but none was specified. + """ + if skeleton is None: + if len(self.skeletons) != 1: + raise ValueError( + "Skeleton must be specified when there is more than one skeleton in " + "the labels." + ) + return self.skeleton + return skeleton def rename_nodes( self, name_map: dict[NodeOrIndex, str] | list[str], skeleton: Skeleton | None = None, ): - if skeleton is None: - if len(self.skeletons) != 1: - raise ValueError( - "Skeleton must be specified when there is more than one skeleton in " - "the labels." - ) - skeleton = self.skeleton + skeleton = self._validate_and_get_skeleton(skeleton) # ... rest of the methodThis refactor would:
- Reduce code duplication
- Make the validation logic more maintainable
- Keep the same functionality and error handling
586-647
: LGTM: Comprehensive skeleton replacement implementation with proper node mapping.The method handles both explicit and automatic node mapping well, with appropriate warnings about data loss. However, there's an opportunity for performance optimization.
Consider optimizing the instance update loop for better performance with large datasets:
- for inst in self.instances: - if inst.skeleton == old_skeleton: - inst.replace_skeleton(new_skeleton, rev_node_map=rev_node_map) + # Filter instances first to avoid checking skeleton equality for every instance + affected_instances = (inst for inst in self.instances if inst.skeleton == old_skeleton) + for inst in affected_instances: + inst.replace_skeleton(new_skeleton, rev_node_map=rev_node_map)This optimization:
- Reduces the number of skeleton equality checks
- Is more memory efficient by using a generator expression
- Maintains the same functionality
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (5)
sleap_io/model/instance.py
(2 hunks)sleap_io/model/labels.py
(2 hunks)sleap_io/model/skeleton.py
(6 hunks)tests/model/test_instance.py
(1 hunks)tests/model/test_labels.py
(2 hunks)
🧰 Additional context used
🪛 Ruff
sleap_io/model/skeleton.py
264-264: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
271-271: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
273-273: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
280-280: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
282-282: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
284-284: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
302-302: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
305-305: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
345-345: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
362-362: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
373-373: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
399-399: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
444-444: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
453-453: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
455-455: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
tests/model/test_labels.py
3-3: numpy.testing.assert_equal
imported but unused
Remove unused import: numpy.testing.assert_equal
(F401)
🔇 Additional comments (11)
sleap_io/model/instance.py (2)
14-14
: LGTM!
The import is correctly placed and necessary for type hints in the new replace_skeleton
method.
312-341
: Implementation is correct but can be optimized.
The method correctly handles point updates when the skeleton changes, but could benefit from optimization.
The previous review comment about optimizing for the no-change case is still valid. Consider adding a fast path check before creating a new dictionary.
sleap_io/model/skeleton.py (4)
74-75
: LGTM! Good addition of type alias.
The NodeOrIndex
type alias improves code readability and type safety by clearly defining acceptable node reference types.
323-349
: LGTM! Well-implemented utility method.
The require_node
method provides a robust way to handle node references with good error handling and flexibility.
🧰 Tools
🪛 Ruff
345-345: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
552-589
: LGTM! Well-implemented reordering functionality.
The implementation includes proper validation and clear documentation about potential side effects on instances.
161-185
: 🛠️ Refactor suggestion
Consider adding duplicate node name detection.
While the cache rebuilding logic is solid, it might silently overwrite entries if there are duplicate node names.
def rebuild_cache(self, nodes: list[Node] | None = None):
if nodes is None:
nodes = self.nodes
+ # Check for duplicate names
+ name_counts = {}
+ for node in nodes:
+ name_counts[node.name] = name_counts.get(node.name, 0) + 1
+ if name_counts[node.name] > 1:
+ raise ValueError(f"Duplicate node name found: {node.name}")
self._name_to_node_cache = {node.name: node for node in nodes}
self._node_to_ind_cache = {node: i for i, node in enumerate(nodes)}
Likely invalid or redundant comment.
tests/model/test_labels.py (3)
638-668
: LGTM! Well-structured test for the instances property.
The test thoroughly covers both single and multiple instance scenarios, ensuring proper instance counting across frames.
684-701
: LGTM! Comprehensive test coverage for node removal.
The test effectively verifies:
- Node removal from skeleton
- Impact on instance data
- Error handling for invalid cases
721-749
: LGTM! Excellent test coverage for skeleton replacement.
The test comprehensively covers all skeleton replacement scenarios:
- Full node mapping
- Partial (inferred) mapping
- No mapping case
- Proper handling of instance data updates
sleap_io/model/labels.py (2)
16-16
: LGTM: Import changes are appropriate.
The new imports support the added functionality for skeleton manipulation and type hints.
Also applies to: 24-24, 26-26
463-466
: LGTM: Efficient implementation of instances property.
The use of a generator expression instead of a list comprehension is memory-efficient, especially for large datasets with many instances.
def test_instance_update_skeleton(): | ||
skel = Skeleton(["A", "B", "C"]) | ||
inst = Instance.from_numpy([[0, 0], [1, 1], [2, 2]], skeleton=skel) | ||
|
||
# No need to update on rename | ||
skel.rename_nodes({"A": "X", "B": "Y", "C": "Z"}) | ||
assert inst["X"].x == 0 | ||
assert inst["Y"].x == 1 | ||
assert inst["Z"].x == 2 | ||
assert_equal(inst.numpy(), [[0, 0], [1, 1], [2, 2]]) | ||
|
||
# Remove a node from the skeleton | ||
Y = skel["Y"] | ||
skel.remove_node("Y") | ||
assert Y not in skel | ||
|
||
with pytest.raises(KeyError): | ||
inst.numpy() # .numpy() breaks without update | ||
assert Y in inst.points # and the points dict still has the old key | ||
inst.update_skeleton() | ||
assert Y not in inst.points # after update, the old key is gone | ||
assert_equal(inst.numpy(), [[0, 0], [2, 2]]) | ||
|
||
# Reorder nodes | ||
skel.reorder_nodes(["Z", "X"]) | ||
assert_equal(inst.numpy(), [[2, 2], [0, 0]]) # .numpy() works without update | ||
assert ( | ||
list(inst.points.keys()) != skel.nodes | ||
) # but the points dict still has the old order | ||
inst.update_skeleton() | ||
assert list(inst.points.keys()) == skel.nodes # after update, the order is correct | ||
|
||
|
||
def test_instance_replace_skeleton(): | ||
# Full replacement | ||
old_skel = Skeleton(["A", "B", "C"]) | ||
inst = Instance.from_numpy([[0, 0], [1, 1], [2, 2]], skeleton=old_skel) | ||
new_skel = Skeleton(["X", "Y", "Z"]) | ||
inst.replace_skeleton(new_skel, node_map={"A": "X", "B": "Y", "C": "Z"}) | ||
assert inst.skeleton == new_skel | ||
assert_equal(inst.numpy(), [[0, 0], [1, 1], [2, 2]]) | ||
assert list(inst.points.keys()) == new_skel.nodes | ||
|
||
# Partial replacement | ||
old_skel = Skeleton(["A", "B", "C"]) | ||
inst = Instance.from_numpy([[0, 0], [1, 1], [2, 2]], skeleton=old_skel) | ||
new_skel = Skeleton(["X", "C", "Y"]) | ||
inst.replace_skeleton(new_skel) | ||
assert inst.skeleton == new_skel | ||
assert_equal(inst.numpy(), [[np.nan, np.nan], [2, 2], [np.nan, np.nan]]) | ||
assert new_skel["C"] in inst.points | ||
assert old_skel["A"] not in inst.points | ||
assert old_skel["C"] not in inst.points | ||
|
||
# Fast path with reverse node map | ||
old_skel = Skeleton(["A", "B", "C"]) | ||
inst = Instance.from_numpy([[0, 0], [1, 1], [2, 2]], skeleton=old_skel) | ||
new_skel = Skeleton(["X", "Y", "Z"]) | ||
rev_node_map = { | ||
new_node: old_node for new_node, old_node in zip(new_skel.nodes, old_skel.nodes) | ||
} | ||
inst.replace_skeleton(new_skel, rev_node_map=rev_node_map) | ||
assert inst.skeleton == new_skel | ||
assert_equal(inst.numpy(), [[0, 0], [1, 1], [2, 2]]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Consider using pytest fixtures to reduce setup duplication.
The new test functions share similar setup code for creating skeletons and instances. Consider extracting these into fixtures to improve test isolation and reduce duplication:
@pytest.fixture
def basic_skeleton():
return Skeleton(["A", "B", "C"])
@pytest.fixture
def basic_instance(basic_skeleton):
return Instance.from_numpy([[0, 0], [1, 1], [2, 2]], skeleton=basic_skeleton)
def test_instance_update_skeleton(basic_skeleton, basic_instance):
# Use fixtures instead of creating new objects
skel = basic_skeleton
inst = basic_instance
...
def test_instance_replace_skeleton(basic_skeleton, basic_instance):
# Use fixtures here as well
old_skel = basic_skeleton
inst = basic_instance
...
def _nodes_on_setattr(self, attr, new_nodes): | ||
"""Callback to update caches when nodes are set.""" | ||
self.rebuild_cache(nodes=new_nodes) | ||
return new_nodes | ||
|
||
nodes: list[Node] = field(factory=list, on_setattr=_update_node_map) | ||
nodes: list[Node] = field( | ||
factory=list, | ||
on_setattr=_nodes_on_setattr, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Consider making the callback more robust.
The callback mechanism for node updates could benefit from additional validation.
def _nodes_on_setattr(self, attr, new_nodes):
"""Callback to update caches when nodes are set."""
+ if not isinstance(new_nodes, list):
+ raise TypeError(f"Expected list of nodes, got {type(new_nodes)}")
+ if not all(isinstance(node, (Node, str)) for node in new_nodes):
+ raise TypeError("All nodes must be Node objects or strings")
self.rebuild_cache(nodes=new_nodes)
return new_nodes
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def _nodes_on_setattr(self, attr, new_nodes): | |
"""Callback to update caches when nodes are set.""" | |
self.rebuild_cache(nodes=new_nodes) | |
return new_nodes | |
nodes: list[Node] = field(factory=list, on_setattr=_update_node_map) | |
nodes: list[Node] = field( | |
factory=list, | |
on_setattr=_nodes_on_setattr, | |
) | |
def _nodes_on_setattr(self, attr, new_nodes): | |
"""Callback to update caches when nodes are set.""" | |
if not isinstance(new_nodes, list): | |
raise TypeError(f"Expected list of nodes, got {type(new_nodes)}") | |
if not all(isinstance(node, (Node, str)) for node in new_nodes): | |
raise TypeError("All nodes must be Node objects or strings") | |
self.rebuild_cache(nodes=new_nodes) | |
return new_nodes | |
nodes: list[Node] = field( | |
factory=list, | |
on_setattr=_nodes_on_setattr, | |
) |
@@ -1,6 +1,6 @@ | |||
"""Test methods and functions in the sleap_io.model.labels file.""" | |||
|
|||
from numpy.testing import assert_equal | |||
from numpy.testing import assert_equal, assert_allclose |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove unused import.
The assert_equal
import from numpy.testing
is not used in the code.
-from numpy.testing import assert_equal, assert_allclose
+from numpy.testing import assert_allclose
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
from numpy.testing import assert_equal, assert_allclose | |
from numpy.testing import assert_allclose |
🧰 Tools
🪛 Ruff
3-3: numpy.testing.assert_equal
imported but unused
Remove unused import: numpy.testing.assert_equal
(F401)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 5
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (1)
sleap_io/model/skeleton.py
(6 hunks)
🧰 Additional context used
🪛 Ruff
sleap_io/model/skeleton.py
10-10: typing
imported but unused
Remove unused import: typing
(F401)
266-266: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
273-273: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
275-275: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
282-282: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
284-284: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
286-286: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
304-304: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
307-307: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
347-347: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
364-364: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
375-375: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
401-401: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
446-446: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
455-455: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
457-457: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
🔇 Additional comments (8)
sleap_io/model/skeleton.py (8)
74-77
: LGTM!
The type alias NodeOrIndex
is defined correctly using the |
operator for Python 3.10+. The commented-out alternatives for Python 3.12+ and older versions are also provided for reference.
Line range hint 194-207
: LGTM!
The properties edge_inds
and symmetry_inds
have been updated to use tuple[int, int]
instead of Tuple[int, int]
, reflecting a shift to a more modern type hinting style.
301-315
: LGTM!
The add_node
method correctly checks if the node already exists in the skeleton and raises a ValueError
if it does. It also handles the case where the input is a string by creating a new Node
object. The internal caches are updated atomically after the node is added.
🧰 Tools
🪛 Ruff
304-304: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
307-307: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
316-324
: LGTM!
The add_nodes
method is a convenience method that calls add_node
for each node in the input list. This ensures that the validation and cache updates are handled correctly for each node.
325-351
: LGTM!
The require_node
method is a useful utility that ensures a Node
object is returned for various input types (Node, name, or index). It also handles adding missing nodes to the skeleton if add_missing
is True
, which is the default behavior.
🧰 Tools
🪛 Ruff
347-347: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
554-591
: LGTM!
The reorder_nodes
method correctly updates the order of nodes in the skeleton. It raises a ValueError
if the new order is not the same length as the current nodes. The method also handles updating the internal caches after reordering.
The warning about instances not being automatically updated is valid and the recommendation to use Labels.reorder_nodes()
or manually update instances is appropriate.
99-107
: 🛠️ Refactor suggestion
Consider making the callback more robust.
The callback mechanism for node updates could benefit from additional validation.
def _nodes_on_setattr(self, attr, new_nodes):
"""Callback to update caches when nodes are set."""
+ if not isinstance(new_nodes, list):
+ raise TypeError(f"Expected list of nodes, got {type(new_nodes)}")
+ if not all(isinstance(node, Node) for node in new_nodes):
+ raise TypeError("All nodes must be Node objects")
self.rebuild_cache(nodes=new_nodes)
return new_nodes
Likely invalid or redundant comment.
163-187
: 🛠️ Refactor suggestion
Consider adding error handling for duplicate node names.
While the cache rebuilding logic is solid, it might silently overwrite entries if there are duplicate node names.
def rebuild_cache(self, nodes: list[Node] | None = None):
if nodes is None:
nodes = self.nodes
+ # Check for duplicate names
+ name_counts = {}
+ for node in nodes:
+ name_counts[node.name] = name_counts.get(node.name, 0) + 1
+ if name_counts[node.name] > 1:
+ raise ValueError(f"Duplicate node name found: {node.name}")
self._name_to_node_cache = {node.name: node for node in nodes}
self._node_to_ind_cache = {node: i for i, node in enumerate(nodes)}
Likely invalid or redundant comment.
@@ -7,7 +7,7 @@ | |||
|
|||
from __future__ import annotations | |||
from attrs import define, field | |||
from typing import Optional, Tuple, Union | |||
import typing |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove unused import.
The typing
module is imported but not used in the code.
-import typing
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
import typing |
🧰 Tools
🪛 Ruff
10-10: typing
imported but unused
Remove unused import: typing
(F401)
def remove_nodes(self, nodes: list[NodeOrIndex]): | ||
"""Remove nodes from the skeleton. | ||
|
||
Args: | ||
nodes: A list of node names, indices, or `Node` objects to remove. | ||
|
||
Notes: | ||
This method handles updating the lookup caches necessary for indexing nodes | ||
by name. | ||
|
||
Any edges and symmetries that are connected to the removed nodes will also | ||
be removed. | ||
|
||
Warning: | ||
**This method does NOT update instances** that use this skeleton to reflect | ||
changes. | ||
|
||
It is recommended to use the `Labels.remove_nodes()` method which will | ||
update all contained to reflect the changes made to the skeleton. | ||
|
||
To manually update instances after this method is called, call | ||
`instance.update_nodes()` on each instance that uses this skeleton. | ||
""" | ||
# Standardize input and make a pre-mutation copy before keys are changed. | ||
rm_node_objs = [self.require_node(node, add_missing=False) for node in nodes] | ||
|
||
# Remove nodes from the skeleton. | ||
for node in rm_node_objs: | ||
self.nodes.remove(node) | ||
del self._name_to_node_cache[node.name] | ||
|
||
# Remove edges connected to the removed nodes. | ||
self.edges = [ | ||
edge | ||
for edge in self.edges | ||
if edge.source not in rm_node_objs and edge.destination not in rm_node_objs | ||
] | ||
|
||
# Remove symmetries connected to the removed nodes. | ||
self.symmetries = [ | ||
symmetry | ||
for symmetry in self.symmetries | ||
if symmetry.nodes.isdisjoint(rm_node_objs) | ||
] | ||
|
||
# Update node index map. | ||
self.rebuild_cache() | ||
|
||
def remove_node(self, node: NodeOrIndex): | ||
"""Remove a single node from the skeleton. | ||
|
||
Args: | ||
node: The node to remove. Can be specified as a string name, integer index, | ||
or `Node` object. | ||
|
||
Notes: | ||
This method handles updating the lookup caches necessary for indexing nodes | ||
by name. | ||
|
||
Any edges and symmetries that are connected to the removed node will also be | ||
removed. | ||
|
||
Warning: | ||
**This method does NOT update instances** that use this skeleton to reflect | ||
changes. | ||
|
||
It is recommended to use the `Labels.remove_nodes()` method which will | ||
update all contained instances to reflect the changes made to the skeleton. | ||
|
||
To manually update instances after this method is called, call | ||
`Instance.update_skeleton()` on each instance that uses this skeleton. | ||
""" | ||
self.remove_nodes([node]) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Consider making node removal atomic.
While the node removal implementation is thorough, it could fail mid-operation leaving the skeleton in an inconsistent state.
def remove_nodes(self, nodes: list[NodeOrIndex]):
# Standardize input and make a pre-mutation copy before keys are changed.
rm_node_objs = [self.require_node(node, add_missing=False) for node in nodes]
+
+ # Create copies of current state
+ new_nodes = self.nodes.copy()
+ new_edges = self.edges.copy()
+ new_symmetries = self.symmetries.copy()
+ new_name_cache = self._name_to_node_cache.copy()
+
+ try:
+ # Remove nodes
+ for node in rm_node_objs:
+ new_nodes.remove(node)
+ del new_name_cache[node.name]
+
+ # Update edges and symmetries
+ new_edges = [
+ edge for edge in new_edges
+ if edge.source not in rm_node_objs and edge.destination not in rm_node_objs
+ ]
+ new_symmetries = [
+ symmetry for symmetry in new_symmetries
+ if symmetry.nodes.isdisjoint(rm_node_objs)
+ ]
+
+ # Commit changes
+ self.nodes = new_nodes
+ self.edges = new_edges
+ self.symmetries = new_symmetries
+ self._name_to_node_cache = new_name_cache
+ self.rebuild_cache()
+ except Exception as e:
+ # Rollback on error
+ raise ValueError(f"Failed to remove nodes: {str(e)}")
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def remove_nodes(self, nodes: list[NodeOrIndex]): | |
"""Remove nodes from the skeleton. | |
Args: | |
nodes: A list of node names, indices, or `Node` objects to remove. | |
Notes: | |
This method handles updating the lookup caches necessary for indexing nodes | |
by name. | |
Any edges and symmetries that are connected to the removed nodes will also | |
be removed. | |
Warning: | |
**This method does NOT update instances** that use this skeleton to reflect | |
changes. | |
It is recommended to use the `Labels.remove_nodes()` method which will | |
update all contained to reflect the changes made to the skeleton. | |
To manually update instances after this method is called, call | |
`instance.update_nodes()` on each instance that uses this skeleton. | |
""" | |
# Standardize input and make a pre-mutation copy before keys are changed. | |
rm_node_objs = [self.require_node(node, add_missing=False) for node in nodes] | |
# Remove nodes from the skeleton. | |
for node in rm_node_objs: | |
self.nodes.remove(node) | |
del self._name_to_node_cache[node.name] | |
# Remove edges connected to the removed nodes. | |
self.edges = [ | |
edge | |
for edge in self.edges | |
if edge.source not in rm_node_objs and edge.destination not in rm_node_objs | |
] | |
# Remove symmetries connected to the removed nodes. | |
self.symmetries = [ | |
symmetry | |
for symmetry in self.symmetries | |
if symmetry.nodes.isdisjoint(rm_node_objs) | |
] | |
# Update node index map. | |
self.rebuild_cache() | |
def remove_node(self, node: NodeOrIndex): | |
"""Remove a single node from the skeleton. | |
Args: | |
node: The node to remove. Can be specified as a string name, integer index, | |
or `Node` object. | |
Notes: | |
This method handles updating the lookup caches necessary for indexing nodes | |
by name. | |
Any edges and symmetries that are connected to the removed node will also be | |
removed. | |
Warning: | |
**This method does NOT update instances** that use this skeleton to reflect | |
changes. | |
It is recommended to use the `Labels.remove_nodes()` method which will | |
update all contained instances to reflect the changes made to the skeleton. | |
To manually update instances after this method is called, call | |
`Instance.update_skeleton()` on each instance that uses this skeleton. | |
""" | |
self.remove_nodes([node]) | |
def remove_nodes(self, nodes: list[NodeOrIndex]): | |
"""Remove nodes from the skeleton. | |
Args: | |
nodes: A list of node names, indices, or `Node` objects to remove. | |
Notes: | |
This method handles updating the lookup caches necessary for indexing nodes | |
by name. | |
Any edges and symmetries that are connected to the removed nodes will also | |
be removed. | |
Warning: | |
**This method does NOT update instances** that use this skeleton to reflect | |
changes. | |
It is recommended to use the `Labels.remove_nodes()` method which will | |
update all contained to reflect the changes made to the skeleton. | |
To manually update instances after this method is called, call | |
`instance.update_nodes()` on each instance that uses this skeleton. | |
""" | |
# Standardize input and make a pre-mutation copy before keys are changed. | |
rm_node_objs = [self.require_node(node, add_missing=False) for node in nodes] | |
# Create copies of current state | |
new_nodes = self.nodes.copy() | |
new_edges = self.edges.copy() | |
new_symmetries = self.symmetries.copy() | |
new_name_cache = self._name_to_node_cache.copy() | |
try: | |
# Remove nodes | |
for node in rm_node_objs: | |
new_nodes.remove(node) | |
del new_name_cache[node.name] | |
# Update edges and symmetries | |
new_edges = [ | |
edge for edge in new_edges | |
if edge.source not in rm_node_objs and edge.destination not in rm_node_objs | |
] | |
new_symmetries = [ | |
symmetry for symmetry in new_symmetries | |
if symmetry.nodes.isdisjoint(rm_node_objs) | |
] | |
# Commit changes | |
self.nodes = new_nodes | |
self.edges = new_edges | |
self.symmetries = new_symmetries | |
self._name_to_node_cache = new_name_cache | |
self.rebuild_cache() | |
except Exception as e: | |
# Rollback on error | |
raise ValueError(f"Failed to remove nodes: {str(e)}") | |
def remove_node(self, node: NodeOrIndex): | |
"""Remove a single node from the skeleton. | |
Args: | |
node: The node to remove. Can be specified as a string name, integer index, | |
or `Node` object. | |
Notes: | |
This method handles updating the lookup caches necessary for indexing nodes | |
by name. | |
Any edges and symmetries that are connected to the removed node will also be | |
removed. | |
Warning: | |
**This method does NOT update instances** that use this skeleton to reflect | |
changes. | |
It is recommended to use the `Labels.remove_nodes()` method which will | |
update all contained instances to reflect the changes made to the skeleton. | |
To manually update instances after this method is called, call | |
`Instance.update_skeleton()` on each instance that uses this skeleton. | |
""" | |
self.remove_nodes([node]) |
def rename_nodes(self, name_map: dict[NodeOrIndex, str] | list[str]): | ||
"""Rename nodes in the skeleton. | ||
|
||
Args: | ||
name_map: A dictionary mapping old node names to new node names. Keys can be | ||
specified as `Node` objects, integer indices, or string names. Values | ||
must be specified as string names. | ||
|
||
If a list of strings is provided of the same length as the current | ||
nodes, the nodes will be renamed to the names in the list in order. | ||
|
||
Raises: | ||
ValueError: If the new node names exist in the skeleton or if the old node | ||
names are not found in the skeleton. | ||
|
||
Notes: | ||
This method should always be used when renaming nodes in the skeleton as it | ||
handles updating the lookup caches necessary for indexing nodes by name. | ||
|
||
After renaming, instances using this skeleton **do NOT need to be updated** | ||
as the nodes are stored by reference in the skeleton, so changes are | ||
reflected automatically. | ||
|
||
Example: | ||
>>> skel = Skeleton(["A", "B", "C"], edges=[("A", "B"), ("B", "C")]) | ||
>>> skel.rename_nodes({"A": "X", "B": "Y", "C": "Z"}) | ||
>>> skel.node_names | ||
["X", "Y", "Z"] | ||
>>> skel.rename_nodes(["a", "b", "c"]) | ||
>>> skel.node_names | ||
["a", "b", "c"] | ||
""" | ||
if type(name_map) == list: | ||
if len(name_map) != len(self.nodes): | ||
raise ValueError( | ||
"List of new node names must be the same length as the current " | ||
"nodes." | ||
) | ||
name_map = {node: name for node, name in zip(self.nodes, name_map)} | ||
|
||
for old_name, new_name in name_map.items(): | ||
if type(old_name) == Node: | ||
old_name = old_name.name | ||
if type(old_name) == int: | ||
old_name = self.nodes[old_name].name | ||
|
||
if old_name not in self._name_to_node_cache: | ||
raise ValueError(f"Node '{old_name}' not found in the skeleton.") | ||
if new_name in self._name_to_node_cache: | ||
raise ValueError(f"Node '{new_name}' already exists in the skeleton.") | ||
|
||
node = self._name_to_node_cache[old_name] | ||
node.name = new_name | ||
self._name_to_node_cache[new_name] = node | ||
del self._name_to_node_cache[old_name] | ||
|
||
def rename_node(self, old_name: NodeOrIndex, new_name: str): | ||
"""Rename a single node in the skeleton. | ||
|
||
Args: | ||
old_name: The name of the node to rename. Can also be specified as an | ||
integer index or `Node` object. | ||
new_name: The new name for the node. | ||
""" | ||
self.rename_nodes({old_name: new_name}) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Consider making node renaming atomic.
While the implementation is thorough, it could fail mid-operation leaving the skeleton in an inconsistent state.
def rename_nodes(self, name_map: dict[NodeOrIndex, str] | list[str]):
+ # Validate and prepare all changes first
+ updates = {}
if isinstance(name_map, list):
if len(name_map) != len(self.nodes):
raise ValueError(
"List of new node names must be the same length as the current nodes."
)
name_map = {node: name for node, name in zip(self.nodes, name_map)}
for old_name, new_name in name_map.items():
if isinstance(old_name, Node):
old_name = old_name.name
if isinstance(old_name, int):
old_name = self.nodes[old_name].name
if old_name not in self._name_to_node_cache:
raise ValueError(f"Node '{old_name}' not found in the skeleton.")
if new_name in self._name_to_node_cache:
raise ValueError(f"Node '{new_name}' already exists in the skeleton.")
+ updates[old_name] = (self._name_to_node_cache[old_name], new_name)
+ # Apply all changes after validation
+ for old_name, (node, new_name) in updates.items():
node = self._name_to_node_cache[old_name]
node.name = new_name
self._name_to_node_cache[new_name] = node
del self._name_to_node_cache[old_name]
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def rename_nodes(self, name_map: dict[NodeOrIndex, str] | list[str]): | |
"""Rename nodes in the skeleton. | |
Args: | |
name_map: A dictionary mapping old node names to new node names. Keys can be | |
specified as `Node` objects, integer indices, or string names. Values | |
must be specified as string names. | |
If a list of strings is provided of the same length as the current | |
nodes, the nodes will be renamed to the names in the list in order. | |
Raises: | |
ValueError: If the new node names exist in the skeleton or if the old node | |
names are not found in the skeleton. | |
Notes: | |
This method should always be used when renaming nodes in the skeleton as it | |
handles updating the lookup caches necessary for indexing nodes by name. | |
After renaming, instances using this skeleton **do NOT need to be updated** | |
as the nodes are stored by reference in the skeleton, so changes are | |
reflected automatically. | |
Example: | |
>>> skel = Skeleton(["A", "B", "C"], edges=[("A", "B"), ("B", "C")]) | |
>>> skel.rename_nodes({"A": "X", "B": "Y", "C": "Z"}) | |
>>> skel.node_names | |
["X", "Y", "Z"] | |
>>> skel.rename_nodes(["a", "b", "c"]) | |
>>> skel.node_names | |
["a", "b", "c"] | |
""" | |
if type(name_map) == list: | |
if len(name_map) != len(self.nodes): | |
raise ValueError( | |
"List of new node names must be the same length as the current " | |
"nodes." | |
) | |
name_map = {node: name for node, name in zip(self.nodes, name_map)} | |
for old_name, new_name in name_map.items(): | |
if type(old_name) == Node: | |
old_name = old_name.name | |
if type(old_name) == int: | |
old_name = self.nodes[old_name].name | |
if old_name not in self._name_to_node_cache: | |
raise ValueError(f"Node '{old_name}' not found in the skeleton.") | |
if new_name in self._name_to_node_cache: | |
raise ValueError(f"Node '{new_name}' already exists in the skeleton.") | |
node = self._name_to_node_cache[old_name] | |
node.name = new_name | |
self._name_to_node_cache[new_name] = node | |
del self._name_to_node_cache[old_name] | |
def rename_node(self, old_name: NodeOrIndex, new_name: str): | |
"""Rename a single node in the skeleton. | |
Args: | |
old_name: The name of the node to rename. Can also be specified as an | |
integer index or `Node` object. | |
new_name: The new name for the node. | |
""" | |
self.rename_nodes({old_name: new_name}) | |
def rename_nodes(self, name_map: dict[NodeOrIndex, str] | list[str]): | |
"""Rename nodes in the skeleton. | |
Args: | |
name_map: A dictionary mapping old node names to new node names. Keys can be | |
specified as `Node` objects, integer indices, or string names. Values | |
must be specified as string names. | |
If a list of strings is provided of the same length as the current | |
nodes, the nodes will be renamed to the names in the list in order. | |
Raises: | |
ValueError: If the new node names exist in the skeleton or if the old node | |
names are not found in the skeleton. | |
Notes: | |
This method should always be used when renaming nodes in the skeleton as it | |
handles updating the lookup caches necessary for indexing nodes by name. | |
After renaming, instances using this skeleton **do NOT need to be updated** | |
as the nodes are stored by reference in the skeleton, so changes are | |
reflected automatically. | |
Example: | |
>>> skel = Skeleton(["A", "B", "C"], edges=[("A", "B"), ("B", "C")]) | |
>>> skel.rename_nodes({"A": "X", "B": "Y", "C": "Z"}) | |
>>> skel.node_names | |
["X", "Y", "Z"] | |
>>> skel.rename_nodes(["a", "b", "c"]) | |
>>> skel.node_names | |
["a", "b", "c"] | |
""" | |
# Validate and prepare all changes first | |
updates = {} | |
if isinstance(name_map, list): | |
if len(name_map) != len(self.nodes): | |
raise ValueError( | |
"List of new node names must be the same length as the current " | |
"nodes." | |
) | |
name_map = {node: name for node, name in zip(self.nodes, name_map)} | |
for old_name, new_name in name_map.items(): | |
if isinstance(old_name, Node): | |
old_name = old_name.name | |
if isinstance(old_name, int): | |
old_name = self.nodes[old_name].name | |
if old_name not in self._name_to_node_cache: | |
raise ValueError(f"Node '{old_name}' not found in the skeleton.") | |
if new_name in self._name_to_node_cache: | |
raise ValueError(f"Node '{new_name}' already exists in the skeleton.") | |
updates[old_name] = (self._name_to_node_cache[old_name], new_name) | |
# Apply all changes after validation | |
for old_name, (node, new_name) in updates.items(): | |
node = self._name_to_node_cache[old_name] | |
node.name = new_name | |
self._name_to_node_cache[new_name] = node | |
del self._name_to_node_cache[old_name] | |
def rename_node(self, old_name: NodeOrIndex, new_name: str): | |
"""Rename a single node in the skeleton. | |
Args: | |
old_name: The name of the node to rename. Can also be specified as an | |
integer index or `Node` object. | |
new_name: The new name for the node. | |
""" | |
self.rename_nodes({old_name: new_name}) |
🧰 Tools
🪛 Ruff
446-446: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
455-455: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
457-457: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 5
🧹 Outside diff range and nitpick comments (1)
sleap_io/model/labels.py (1)
468-512
: Add validation for node name uniqueness and list length.While the implementation is solid, consider adding these validations:
- Ensure new node names are unique when using dict input
- Validate that the list length matches the node count when using list input
def rename_nodes( self, name_map: dict[NodeOrIndex, str] | list[str], skeleton: Skeleton | None = None, ): if skeleton is None: if len(self.skeletons) != 1: raise ValueError( "Skeleton must be specified when there is more than one skeleton in " "the labels." ) skeleton = self.skeleton + # Validate input + if isinstance(name_map, list): + if len(name_map) != len(skeleton.nodes): + raise ValueError( + f"Expected {len(skeleton.nodes)} names but got {len(name_map)}." + ) + else: + new_names = set(name_map.values()) + if len(new_names) != len(name_map): + raise ValueError("New node names must be unique.") skeleton.rename_nodes(name_map)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (2)
sleap_io/model/labels.py
(2 hunks)sleap_io/model/skeleton.py
(6 hunks)
🧰 Additional context used
🪛 Ruff
sleap_io/model/skeleton.py
266-266: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
273-273: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
275-275: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
282-282: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
284-284: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
286-286: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
304-304: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
307-307: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
347-347: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
364-364: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
375-375: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
401-401: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
446-446: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
455-455: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
457-457: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
🔇 Additional comments (4)
sleap_io/model/skeleton.py (2)
99-107
: 🛠️ Refactor suggestion
Add type validation in the callback.
The callback mechanism for node updates could benefit from additional validation to ensure type safety.
def _nodes_on_setattr(self, attr, new_nodes):
"""Callback to update caches when nodes are set."""
+ if not isinstance(new_nodes, list):
+ raise TypeError(f"Expected list of nodes, got {type(new_nodes)}")
+ if not all(isinstance(node, (Node, str)) for node in new_nodes):
+ raise TypeError("All nodes must be Node objects or strings")
self.rebuild_cache(nodes=new_nodes)
return new_nodes
Likely invalid or redundant comment.
10-10
:
Remove unused import.
The typing
module is imported but not used directly in the code.
-import typing
Likely invalid or redundant comment.
sleap_io/model/labels.py (2)
463-466
: LGTM! Efficient implementation of instances property.
The generator expression implementation is memory-efficient for iterating over instances.
513-546
: LGTM! Well-implemented node removal with proper instance updates.
The method correctly handles:
- Skeleton validation
- Instance updates
- Documentation of side effects
def __getitem__(self, idx: NodeOrIndex) -> Node: | ||
"""Return a `Node` when indexing by name or integer.""" | ||
if type(idx) == int: | ||
return self.nodes[idx] | ||
elif type(idx) == str: | ||
return self._node_name_map[idx] | ||
return self._name_to_node_cache[idx] | ||
else: | ||
raise IndexError(f"Invalid indexing argument for skeleton: {idx}") | ||
|
||
def __contains__(self, node: NodeOrIndex) -> bool: | ||
"""Check if a node is in the skeleton.""" | ||
if type(node) == str: | ||
return node in self._name_to_node_cache | ||
elif type(node) == Node: | ||
return node in self.nodes | ||
elif type(node) == int: | ||
return 0 <= node < len(self.nodes) | ||
else: | ||
raise ValueError(f"Invalid node type for skeleton: {node}") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Replace type() comparisons with isinstance() checks.
Multiple instances of type() ==
comparisons should be replaced with isinstance()
checks to properly handle inheritance and follow Python best practices.
Apply these changes throughout the file:
- if type(idx) == int:
+ if isinstance(idx, int):
return self.nodes[idx]
- elif type(idx) == str:
+ elif isinstance(idx, str):
return self._name_to_node_cache[idx]
- if type(node) == str:
+ if isinstance(node, str):
return node in self._name_to_node_cache
- elif type(node) == Node:
+ elif isinstance(node, Node):
return node in self.nodes
- elif type(node) == int:
+ elif isinstance(node, int):
return 0 <= node < len(self.nodes)
- if type(src) == tuple:
+ if isinstance(src, tuple):
src, dst = src
- if type(name_map) == list:
+ if isinstance(name_map, list):
if len(name_map) != len(self.nodes):
raise ValueError(
"List of new node names must be the same length as the current "
"nodes."
)
Also applies to: 364-374, 446-458
🧰 Tools
🪛 Ruff
273-273: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
275-275: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
282-282: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
284-284: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
286-286: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
NodeOrIndex = typing.Union[Node, str, int] | ||
# NodeOrIndex: TypeAlias = Node | str | int # py >= 3.10 | ||
# type NodeOrIndex = Node | str | int # py >= 3.12 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Consider using more modern type alias syntax.
The code includes commented alternatives for newer Python versions. Since the code supports Python 3.10+, you could use the more modern syntax.
-NodeOrIndex = typing.Union[Node, str, int]
+NodeOrIndex: TypeAlias = Node | str | int
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
NodeOrIndex = typing.Union[Node, str, int] | |
# NodeOrIndex: TypeAlias = Node | str | int # py >= 3.10 | |
# type NodeOrIndex = Node | str | int # py >= 3.12 | |
NodeOrIndex: TypeAlias = Node | str | int | |
# NodeOrIndex: TypeAlias = Node | str | int # py >= 3.10 | |
# type NodeOrIndex = Node | str | int # py >= 3.12 |
def remove_nodes(self, nodes: list[NodeOrIndex]): | ||
"""Remove nodes from the skeleton. | ||
|
||
Args: | ||
nodes: A list of node names, indices, or `Node` objects to remove. | ||
|
||
Notes: | ||
This method handles updating the lookup caches necessary for indexing nodes | ||
by name. | ||
|
||
Any edges and symmetries that are connected to the removed nodes will also | ||
be removed. | ||
|
||
Warning: | ||
**This method does NOT update instances** that use this skeleton to reflect | ||
changes. | ||
|
||
It is recommended to use the `Labels.remove_nodes()` method which will | ||
update all contained to reflect the changes made to the skeleton. | ||
|
||
To manually update instances after this method is called, call | ||
`instance.update_nodes()` on each instance that uses this skeleton. | ||
""" | ||
# Standardize input and make a pre-mutation copy before keys are changed. | ||
rm_node_objs = [self.require_node(node, add_missing=False) for node in nodes] | ||
|
||
# Remove nodes from the skeleton. | ||
for node in rm_node_objs: | ||
self.nodes.remove(node) | ||
del self._name_to_node_cache[node.name] | ||
|
||
# Remove edges connected to the removed nodes. | ||
self.edges = [ | ||
edge | ||
for edge in self.edges | ||
if edge.source not in rm_node_objs and edge.destination not in rm_node_objs | ||
] | ||
|
||
# Remove symmetries connected to the removed nodes. | ||
self.symmetries = [ | ||
symmetry | ||
for symmetry in self.symmetries | ||
if symmetry.nodes.isdisjoint(rm_node_objs) | ||
] | ||
|
||
# Update node index map. | ||
self.rebuild_cache() | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Consider making node removal atomic.
While the node removal implementation is thorough, it could fail mid-operation leaving the skeleton in an inconsistent state. Consider implementing an atomic operation pattern.
def remove_nodes(self, nodes: list[NodeOrIndex]):
# Standardize input and make a pre-mutation copy before keys are changed.
rm_node_objs = [self.require_node(node, add_missing=False) for node in nodes]
+
+ # Create copies of current state
+ new_nodes = self.nodes.copy()
+ new_edges = self.edges.copy()
+ new_symmetries = self.symmetries.copy()
+ new_name_cache = self._name_to_node_cache.copy()
+
+ try:
+ # Remove nodes
+ for node in rm_node_objs:
+ new_nodes.remove(node)
+ del new_name_cache[node.name]
+
+ # Update edges and symmetries
+ new_edges = [
+ edge for edge in new_edges
+ if edge.source not in rm_node_objs and edge.destination not in rm_node_objs
+ ]
+ new_symmetries = [
+ symmetry for symmetry in new_symmetries
+ if symmetry.nodes.isdisjoint(rm_node_objs)
+ ]
+
+ # Commit changes
+ self.nodes = new_nodes
+ self.edges = new_edges
+ self.symmetries = new_symmetries
+ self._name_to_node_cache = new_name_cache
+ self.rebuild_cache()
+ except Exception as e:
+ # Rollback on error
+ raise ValueError(f"Failed to remove nodes: {str(e)}")
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def remove_nodes(self, nodes: list[NodeOrIndex]): | |
"""Remove nodes from the skeleton. | |
Args: | |
nodes: A list of node names, indices, or `Node` objects to remove. | |
Notes: | |
This method handles updating the lookup caches necessary for indexing nodes | |
by name. | |
Any edges and symmetries that are connected to the removed nodes will also | |
be removed. | |
Warning: | |
**This method does NOT update instances** that use this skeleton to reflect | |
changes. | |
It is recommended to use the `Labels.remove_nodes()` method which will | |
update all contained to reflect the changes made to the skeleton. | |
To manually update instances after this method is called, call | |
`instance.update_nodes()` on each instance that uses this skeleton. | |
""" | |
# Standardize input and make a pre-mutation copy before keys are changed. | |
rm_node_objs = [self.require_node(node, add_missing=False) for node in nodes] | |
# Remove nodes from the skeleton. | |
for node in rm_node_objs: | |
self.nodes.remove(node) | |
del self._name_to_node_cache[node.name] | |
# Remove edges connected to the removed nodes. | |
self.edges = [ | |
edge | |
for edge in self.edges | |
if edge.source not in rm_node_objs and edge.destination not in rm_node_objs | |
] | |
# Remove symmetries connected to the removed nodes. | |
self.symmetries = [ | |
symmetry | |
for symmetry in self.symmetries | |
if symmetry.nodes.isdisjoint(rm_node_objs) | |
] | |
# Update node index map. | |
self.rebuild_cache() | |
def remove_nodes(self, nodes: list[NodeOrIndex]): | |
"""Remove nodes from the skeleton. | |
Args: | |
nodes: A list of node names, indices, or `Node` objects to remove. | |
Notes: | |
This method handles updating the lookup caches necessary for indexing nodes | |
by name. | |
Any edges and symmetries that are connected to the removed nodes will also | |
be removed. | |
Warning: | |
**This method does NOT update instances** that use this skeleton to reflect | |
changes. | |
It is recommended to use the `Labels.remove_nodes()` method which will | |
update all contained to reflect the changes made to the skeleton. | |
To manually update instances after this method is called, call | |
`instance.update_nodes()` on each instance that uses this skeleton. | |
""" | |
# Standardize input and make a pre-mutation copy before keys are changed. | |
rm_node_objs = [self.require_node(node, add_missing=False) for node in nodes] | |
# Create copies of current state | |
new_nodes = self.nodes.copy() | |
new_edges = self.edges.copy() | |
new_symmetries = self.symmetries.copy() | |
new_name_cache = self._name_to_node_cache.copy() | |
try: | |
# Remove nodes | |
for node in rm_node_objs: | |
new_nodes.remove(node) | |
del new_name_cache[node.name] | |
# Update edges and symmetries | |
new_edges = [ | |
edge for edge in new_edges | |
if edge.source not in rm_node_objs and edge.destination not in rm_node_objs | |
] | |
new_symmetries = [ | |
symmetry for symmetry in new_symmetries | |
if symmetry.nodes.isdisjoint(rm_node_objs) | |
] | |
# Commit changes | |
self.nodes = new_nodes | |
self.edges = new_edges | |
self.symmetries = new_symmetries | |
self._name_to_node_cache = new_name_cache | |
self.rebuild_cache() | |
except Exception as e: | |
# Rollback on error | |
raise ValueError(f"Failed to remove nodes: {str(e)}") |
sleap_io/model/skeleton.py
Outdated
def add_edge( | ||
self, | ||
src: NodeOrIndex | Edge | tuple[NodeOrIndex, NodeOrIndex], | ||
dst: NodeOrIndex | None = None, | ||
): | ||
"""Add an `Edge` to the skeleton. | ||
|
||
Args: | ||
src: The source `Node` or name of the source node. | ||
dst: The destination `Node` or name of the destination node. | ||
src: The source node specified as a `Node`, name or index. | ||
dst: The destination node specified as a `Node`, name or index. | ||
""" | ||
edge = None | ||
if type(src) == tuple: | ||
src, dst = src | ||
|
||
if isinstance(src, NodeOrIndex): | ||
if not isinstance(dst, NodeOrIndex): | ||
raise ValueError("Destination node must be specified.") | ||
|
||
src = self.require_node(src) | ||
dst = self.require_node(dst) | ||
edge = Edge(src, dst) | ||
|
||
if type(src) == Edge: | ||
edge = src | ||
if edge not in self.edges: | ||
self.edges.append(edge) | ||
if edge.source not in self.nodes: | ||
self.add_node(edge.source) | ||
if edge.destination not in self.nodes: | ||
self.add_node(edge.destination) | ||
return | ||
|
||
if type(src) == str or type(src) == Node: | ||
try: | ||
src = self.index(src) | ||
except KeyError: | ||
self.add_node(src) | ||
src = self.index(src) | ||
|
||
if type(dst) == str or type(dst) == Node: | ||
try: | ||
dst = self.index(dst) | ||
except KeyError: | ||
self.add_node(dst) | ||
dst = self.index(dst) | ||
|
||
edge = Edge(self.nodes[src], self.nodes[dst]) | ||
|
||
if edge not in self.edges: | ||
self.edges.append(edge) | ||
|
||
def add_edges(self, edges: list[Edge | tuple[NodeOrIndex, NodeOrIndex]]): | ||
"""Add multiple `Edge`s to the skeleton. | ||
|
||
Args: | ||
edges: A list of `Edge` objects or 2-tuples of source and destination nodes. | ||
""" | ||
for edge in edges: | ||
self.add_edge(edge) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Consider adding cycle detection in edge addition.
While the edge management is well-implemented, adding cycle detection would prevent creation of cyclic dependencies in the skeleton graph.
def add_edge(
self,
src: NodeOrIndex | Edge | tuple[NodeOrIndex, NodeOrIndex],
dst: NodeOrIndex | None = None,
):
edge = None
if isinstance(src, tuple):
src, dst = src
if isinstance(src, NodeOrIndex):
if not isinstance(dst, NodeOrIndex):
raise ValueError("Destination node must be specified.")
src = self.require_node(src)
dst = self.require_node(dst)
+ # Check for cycles
+ if self._would_create_cycle(src, dst):
+ raise ValueError("Adding this edge would create a cycle in the skeleton graph")
edge = Edge(src, dst)
if isinstance(src, Edge):
edge = src
if edge not in self.edges:
self.edges.append(edge)
+def _would_create_cycle(self, src: Node, dst: Node) -> bool:
+ """Check if adding an edge would create a cycle."""
+ visited = set()
+ def dfs(node: Node) -> bool:
+ if node == src:
+ return True
+ visited.add(node)
+ for edge in self.edges:
+ if edge.source == node and edge.destination not in visited:
+ if dfs(edge.destination):
+ return True
+ return False
+ return dfs(dst)
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def add_edge( | |
self, | |
src: NodeOrIndex | Edge | tuple[NodeOrIndex, NodeOrIndex], | |
dst: NodeOrIndex | None = None, | |
): | |
"""Add an `Edge` to the skeleton. | |
Args: | |
src: The source `Node` or name of the source node. | |
dst: The destination `Node` or name of the destination node. | |
src: The source node specified as a `Node`, name or index. | |
dst: The destination node specified as a `Node`, name or index. | |
""" | |
edge = None | |
if type(src) == tuple: | |
src, dst = src | |
if isinstance(src, NodeOrIndex): | |
if not isinstance(dst, NodeOrIndex): | |
raise ValueError("Destination node must be specified.") | |
src = self.require_node(src) | |
dst = self.require_node(dst) | |
edge = Edge(src, dst) | |
if type(src) == Edge: | |
edge = src | |
if edge not in self.edges: | |
self.edges.append(edge) | |
if edge.source not in self.nodes: | |
self.add_node(edge.source) | |
if edge.destination not in self.nodes: | |
self.add_node(edge.destination) | |
return | |
if type(src) == str or type(src) == Node: | |
try: | |
src = self.index(src) | |
except KeyError: | |
self.add_node(src) | |
src = self.index(src) | |
if type(dst) == str or type(dst) == Node: | |
try: | |
dst = self.index(dst) | |
except KeyError: | |
self.add_node(dst) | |
dst = self.index(dst) | |
edge = Edge(self.nodes[src], self.nodes[dst]) | |
if edge not in self.edges: | |
self.edges.append(edge) | |
def add_edges(self, edges: list[Edge | tuple[NodeOrIndex, NodeOrIndex]]): | |
"""Add multiple `Edge`s to the skeleton. | |
Args: | |
edges: A list of `Edge` objects or 2-tuples of source and destination nodes. | |
""" | |
for edge in edges: | |
self.add_edge(edge) | |
def add_edge( | |
self, | |
src: NodeOrIndex | Edge | tuple[NodeOrIndex, NodeOrIndex], | |
dst: NodeOrIndex | None = None, | |
): | |
"""Add an `Edge` to the skeleton. | |
Args: | |
src: The source node specified as a `Node`, name or index. | |
dst: The destination node specified as a `Node`, name or index. | |
""" | |
edge = None | |
if isinstance(src, tuple): | |
src, dst = src | |
if isinstance(src, NodeOrIndex): | |
if not isinstance(dst, NodeOrIndex): | |
raise ValueError("Destination node must be specified.") | |
src = self.require_node(src) | |
dst = self.require_node(dst) | |
# Check for cycles | |
if self._would_create_cycle(src, dst): | |
raise ValueError("Adding this edge would create a cycle in the skeleton graph") | |
edge = Edge(src, dst) | |
if isinstance(src, Edge): | |
edge = src | |
if edge not in self.edges: | |
self.edges.append(edge) | |
def _would_create_cycle(self, src: Node, dst: Node) -> bool: | |
"""Check if adding an edge would create a cycle.""" | |
visited = set() | |
def dfs(node: Node) -> bool: | |
if node == src: | |
return True | |
visited.add(node) | |
for edge in self.edges: | |
if edge.source == node and edge.destination not in visited: | |
if dfs(edge.destination): | |
return True | |
return False | |
return dfs(dst) | |
def add_edges(self, edges: list[Edge | tuple[NodeOrIndex, NodeOrIndex]]): | |
"""Add multiple `Edge`s to the skeleton. | |
Args: | |
edges: A list of `Edge` objects or 2-tuples of source and destination nodes. | |
""" | |
for edge in edges: | |
self.add_edge(edge) |
🧰 Tools
🪛 Ruff
364-364: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
375-375: Use is
and is not
for type comparisons, or isinstance()
for isinstance checks
(E721)
if node_map is None: | ||
node_map = {} | ||
for old_node in old_skeleton.nodes: | ||
for new_node in new_skeleton.nodes: | ||
if old_node.name == new_node.name: | ||
node_map[old_node] = new_node | ||
break | ||
else: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🛠️ Refactor suggestion
Optimize node mapping creation performance.
The current implementation uses nested loops to create the node mapping, which has O(n²) complexity. Consider using a dictionary for O(n) lookup:
if node_map is None:
node_map = {}
+ new_node_map = {node.name: node for node in new_skeleton.nodes}
for old_node in old_skeleton.nodes:
- for new_node in new_skeleton.nodes:
- if old_node.name == new_node.name:
- node_map[old_node] = new_node
- break
+ if old_node.name in new_node_map:
+ node_map[old_node] = new_node_map[old_node.name]
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
if node_map is None: | |
node_map = {} | |
for old_node in old_skeleton.nodes: | |
for new_node in new_skeleton.nodes: | |
if old_node.name == new_node.name: | |
node_map[old_node] = new_node | |
break | |
else: | |
if node_map is None: | |
node_map = {} | |
new_node_map = {node.name: node for node in new_skeleton.nodes} | |
for old_node in old_skeleton.nodes: | |
if old_node.name in new_node_map: | |
node_map[old_node] = new_node_map[old_node.name] | |
else: |
This PR builds on #135 to add a number of methods for
Skeleton
manipulation, addressing #123 and more.API changes:
Skeleton
__contains__(node: NodeOrIndex)
: ReturnsTrue
if a node exists in the skeleton.rebuild_cache()
: Method allowing explicit regeneration of the caching attributes from the nodes._name_to_node_cache
and_node_to_ind_cache
, better reflecting the mapping directionality.require_node(node: NodeOrIndex, add_missing: bool = True)
: Returns aNode
given aNode
,int
orstr
. Ifadd_missing
isTrue
, the node is added or created, otherwise anIndexError
is raised. This is helpful for flexibly converting between node representations with convenient existence handling.add_nodes(list[Node | str])
: Convenience method to add a list of nodes.add_edges(edges: list[Edge | tuple[NodeOrIndex, NodeOrIndex]])
: Convenience method to add a list of edges.rename_nodes(name_map: dict[NodeOrIndex, str] | list[str])
: Method to rename nodes either by specifying a potentially partial mapping from node(s) to new name(s), or a list of new names. Handles updating both theNode.name
attributes and the cache.rename_node(old_name: NodeOrIndex, new_name: str)
: Shorter syntax for renaming a single node.remove_nodes(nodes: list[NodeOrIndex])
: Method for removing nodes from the skeleton and updating caches. Does NOT update corresponding instances.remove_node(node: NodeOrIndex)
: Shorter syntax for removing a single node.reorder_nodes(new_order: list[NodeOrIndex])
: Method for setting the order of the nodes within the skeleton with cache updating. Does NOT update corresponding instances.Instance
/PredictedInstance
update_skeleton()
: Updates thepoints
attribute on the instance to reflect changes in the associated skeleton (removed nodes and reordering). This is called internally after updating the skeleton from theLabels
level, but also exposed for more complex data manipulation workflows.replace_skeleton(new_skeleton: Skeleton, node_map: dict[NodeOrIndex, NodeOrIndex] | None = None, rev_node_map: dict[NodeOrIndex, NodeOrIndex] | None = None)
: Method to replace the skeleton on the instance with optional capability to specify a node mapping so that data stored in thepoints
attribute is retained and associated with the right nodes in the new skeleton. Mapping is specified innode_map
from old to new nodes and defaults to mapping between node objects with the same name.rev_node_map
maps new nodes to old nodes and is used internally when calling from theLabels
level as it bypasses validation.Labels
instances
: Convenience property that returns a generator that loops over all labeled frames and returns all instances. This can be lazily iterated over without having to construct a huge list of all the instances.rename_nodes(name_map: dict[NodeOrIndex, str] | list[str], skeleton: Skeleton | None = None)
: Method to rename nodes in a specified skeleton within the labels.remove_nodes(nodes: list[NodeOrIndex], skeleton: Skeleton | None = None)
: Method to remove nodes in a specified skeleton within the labels. This also updates all instances associated with the skeleton, removing point data for the removed nodes.reorder_nodes(new_order: list[NodeOrIndex], skeleton: Skeleton | None = None)
: Method to reorder nodes in a specified skeleton within the labels. This also updates all instances associated with the skeleton, reordering point data for the nodes.replace_skeleton(new_skeleton: Skeleton, old_skeleton: Skeleton | None = None, node_map: dict[NodeOrIndex, NodeOrIndex] | None = None)
: Method to replace a skeleton entirely within the labels, updating all instances associated with the old skeleton to use the new skeleton, optionally with node remapping to retain previous point data.Summary by CodeRabbit
Release Notes
New Features
Skeleton
andLabels
classes.instances
property in theLabels
class for easier access to instance data.Instance
class with methods to update and replace skeletons.Bug Fixes
Tests
Skeleton
,Labels
, andInstance
classes, validating new functionalities and error conditions.