Add support for the Kumar Lab's JABS format #63

Merged: 20 commits, Sep 20, 2023

Contributor

@SkepticRaven SkepticRaven commented Aug 31, 2023

Adds conversion of JABS pose files into SLEAP.

List of things supported:

  1. Read JABS Pose format v2-v5 into SLEAP
  2. Write JABS Pose format v2-v5 out of SLEAP
    a. 1 pose file per video (as per JABS format)

TODO:

  1. Add support for our training data? -- Deferring to future PR
  2. Static objects not handled
    a. Duplicate across all frames or just carry over to first?
    b. Writing out assumes all data is same without looking at skeleton
  3. Lazy implementations need to become generic
    a. Track support requires integer names (convert to factorize)
    b. Shapes of data may not be guaranteed (eg frame matrix should match shape of video, not annotations available)
    c. Attribute data not preserved... is there a way to preserve it? px_to_cm is important -- Deferring to future PR
    d. Tests/example data?
    e. Enforce 12 keypoint names such that data is compatible with JABS related tools (eg https://github.com/KumarLabJax/JABS-behavior-classifier and https://github.com/KumarLabJax/gaitanalysis)
  4. Run black
  5. v3 writer needs to break apart non-continuous tracks
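
For item 3a, one way to map arbitrary track names onto integer identities is to factorize them in order of first appearance. This is a minimal sketch, not part of sleap_io: `factorize_track_names` is a hypothetical helper, and starting the IDs at 1 is an assumption made so that 0 and negative values stay free for "no identity".

```python
from typing import Dict, Hashable, Iterable, List, Tuple


def factorize_track_names(
    names: Iterable[Hashable],
) -> Tuple[List[int], Dict[Hashable, int]]:
    """Map each distinct track name to an integer ID in order of first appearance.

    IDs start at 1 (an assumption for this sketch) so that 0 and negative
    values remain available to mean "no identity".
    """
    mapping: Dict[Hashable, int] = {}
    codes: List[int] = []
    for name in names:
        if name not in mapping:
            mapping[name] = len(mapping) + 1
        codes.append(mapping[name])
    return codes, mapping


codes, mapping = factorize_track_names(["mouse_a", "mouse_b", "mouse_a"])
print(codes)    # [1, 2, 1]
print(mapping)  # {'mouse_a': 1, 'mouse_b': 2}
```

Keeping the returned mapping around would let a reader translate the integer identities back to the original names.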

Summary by CodeRabbit

  • New Feature: Added support for JABS pose file format. Users can now load and save JABS-style predictions using the new load_jabs and save_jabs functions.
  • New Feature: Introduced a module for handling direct I/O operations with JABS files, including reading and writing JABS pose data, converting labels to JABS format, and generating default skeletons and symmetries.
  • Test: Added comprehensive test coverage for the new JABS I/O functionality, ensuring its correctness and behavior.
  • Chore: Updated test fixtures to include real JABS data for testing purposes.

@talmo
Contributor

talmo commented Sep 6, 2023

Hi @SkepticRaven! Thanks for the initial PR -- this is looking great!!

As far as static objects and other attributes (px_to_cm): These are useful additions and I believe that we could add them to the core data models.

For the sake of keeping this PR manageable, I propose that you go ahead and finish implementing the basic support for JABS formatted files and we'll do a second PR that adds new data model features. The other formats that don't support those can just ignore them.

@codecov

codecov bot commented Sep 6, 2023

Codecov Report

Merging #63 (3fea3bd) into main (e475a47) will decrease coverage by 0.23%.
The diff coverage is 93.06%.

@@            Coverage Diff             @@
##             main      #63      +/-   ##
==========================================
- Coverage   94.68%   94.45%   -0.23%     
==========================================
  Files          12       13       +1     
  Lines        1241     1442     +201     
==========================================
+ Hits         1175     1362     +187     
- Misses         66       80      +14     
Files Changed           Coverage Δ
sleap_io/__init__.py    100.00% <ø> (ø)
sleap_io/io/jabs.py     92.89% <92.89%> (ø)
sleap_io/io/main.py     95.65% <100.00%> (+0.91%) ⬆️

@SkepticRaven
Contributor Author

Hey @talmo.
For static objects, they can be added as just a Label associated with a different Skeleton (that's not named "Mouse"). Having an attribute that says "this thing doesn't move" would be a nice addition at a future time point. I think for now, I'll implement it such that the data appears for the first frame (per video) only.

I agree that attributes (eg px_to_cm) don't yet have a place -- can leave it out of this PR and add it when SLEAP has a place for it!

As for our training data, I think I'll leave that out of this pull request. Our single-mouse training data is in an h5 format (originating from CVAT). Our multi-mouse data has been converted between multiple formats (including SLEAP and Label-Studio, both already supported here) and we haven't released it yet.

@SkepticRaven SkepticRaven marked this pull request as ready for review September 18, 2023 13:45
@coderabbitai
Contributor

coderabbitai bot commented Sep 18, 2023

Walkthrough

This pull request introduces support for JABS pose file format in the SLEAP codebase. It adds new functions for reading and writing JABS files, converting labels to JABS format, and generating default skeletons and symmetries. The changes also include test fixtures and test cases to validate the functionality of these additions.

Changes

| File | Summary |
| --- | --- |
| sleap_io/__init__.py | Exported two new functions, load_jabs and save_jabs. |
| sleap_io/io/jabs.py | Added a module for handling I/O operations with JABS files, including reading, writing, and converting labels. |
| sleap_io/io/main.py | Introduced support for JABS pose file format with load_jabs and save_jabs functions. |
| tests/conftest.py | Imported a new module "jabs" in the test fixtures. |
| tests/fixtures/jabs.py | Added two new fixtures, jabs_real_data_v2 and jabs_real_data_v5, for testing. |
| tests/io/test_jabs.py | Added test functions for various functions in the sleap_io.io.jabs module. |
| tests/io/test_main.py | Added load_jabs and save_jabs functions and a new test case test_jabs to verify their functionality. |

🐇💻

Code hopping through the night,

Adding features, pure delight.

JABS files now in sight,

Tests ensure they're just right! 🎉


Contributor

@coderabbitai coderabbitai bot left a comment

Review Status

Actionable comments generated: 9

Commits Files that changed from the base of the PR and between e475a47 and 70042d3 commits.
Files selected for processing (7)
  • sleap_io/__init__.py (1 hunks)
  • sleap_io/io/jabs.py (1 hunks)
  • sleap_io/io/main.py (2 hunks)
  • tests/conftest.py (1 hunks)
  • tests/fixtures/jabs.py (1 hunks)
  • tests/io/test_jabs.py (1 hunks)
  • tests/io/test_main.py (2 hunks)
Files skipped from review due to trivial changes (3)
  • sleap_io/__init__.py
  • tests/conftest.py
  • tests/io/test_jabs.py
Additional comments (Suppressed): 4
tests/fixtures/jabs.py (1)
  • 1-30: The fixtures jabs_real_data_v2 and jabs_real_data_v5 are well defined and provide clear documentation about the data they return. The paths returned by these fixtures seem to be hardcoded. Ensure that these files exist in the repository and are placed in the correct directory.
sleap_io/io/main.py (2)
  • 79-105: The new functions load_jabs and save_jabs have been added to handle JABS file format. The function signatures and documentation look good. However, please ensure that the skeleton parameter in load_jabs and pose_version parameter in save_jabs are handled correctly in all calls to these functions throughout the codebase.

  • 4-5: The import statement has been updated to include the jabs module from sleap_io.io. This is necessary for the newly added JABS support.

tests/io/test_main.py (1)
  • 6-11: The import statements have been updated to include the new load_jabs and save_jabs functions. Ensure that these functions are correctly implemented in the imported module.

Comment on lines 43 to 52
def test_jabs(tmp_path, jabs_real_data_v2, jabs_real_data_v5):
    labels_single = load_jabs(jabs_real_data_v2)
    assert isinstance(labels_single, Labels)
    save_jabs(labels_single, 2, tmp_path)

    labels_multi = load_jabs(jabs_real_data_v5)
    assert isinstance(labels_multi, Labels)
    save_jabs(labels_multi, 3, tmp_path)
    save_jabs(labels_multi, 4, tmp_path)
    save_jabs(labels_multi, 5, tmp_path)
Contributor

@coderabbitai coderabbitai bot Sep 18, 2023

This is a new test function for the JABS file format support. It tests both loading and saving of JABS files. However, there are no assertions to verify the correctness of the saved files. Consider adding checks to ensure that the saved files match the expected output.

    save_jabs(labels_single, 2, tmp_path)
+    loaded_labels = load_jabs(tmp_path / "labels_single.jabs")
+    assert labels_single == loaded_labels

    save_jabs(labels_multi, 3, tmp_path)
+    loaded_labels = load_jabs(tmp_path / "labels_multi_v3.jabs")
+    assert labels_multi == loaded_labels

    save_jabs(labels_multi, 4, tmp_path)
+    loaded_labels = load_jabs(tmp_path / "labels_multi_v4.jabs")
+    assert labels_multi == loaded_labels

    save_jabs(labels_multi, 5, tmp_path)
+    loaded_labels = load_jabs(tmp_path / "labels_multi_v5.jabs")
+    assert labels_multi == loaded_labels

Contributor Author

v2 check could work, but the remainder do not guarantee sorting of data. On first load/save, the data is equivalent, but not identical. A second round of loading + saving might allow this check to work.
v3 and v4 checks definitely won't work, since data is intentionally discarded.

Also, the filenames are generated based on videos in the annotation and are not labels_*.jabs.
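
The naming rule can be sketched as follows. It mirrors the expression `os.path.splitext(video.filename)[0] + f"_pose_est_v{pose_version}.h5"` from the writer code quoted later in this review; `jabs_output_path` is a hypothetical helper name used only for illustration.

```python
import os
from typing import Optional


def jabs_output_path(
    video_filename: str, pose_version: int, root_folder: Optional[str] = None
) -> str:
    """Build the pose filename for one video: <video stem>_pose_est_v<N>.h5."""
    out_filename = os.path.splitext(video_filename)[0] + f"_pose_est_v{pose_version}.h5"
    if root_folder:
        # Optionally place the output under a separate root folder.
        out_filename = os.path.join(root_folder, out_filename)
    return out_filename


print(jabs_output_path("cage1/session.avi", 5))  # cage1/session_pose_est_v5.h5
```

A test that reloads written files would therefore need to derive the path from the video filename stored in the labels rather than from a fixed name.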

Contributor

@SkepticRaven, thank you for your feedback. I understand the challenges with the v3 and v4 checks due to data being intentionally discarded. In this case, we could modify our tests to only check for the presence of essential data rather than exact equality.

As for the filenames, I apologize for the oversight. We should indeed use the actual filenames generated based on the videos in the annotation. Here's a revised version of the test:

def test_jabs(tmp_path, jabs_real_data_v2, jabs_real_data_v5):
    labels_single = load_jabs(jabs_real_data_v2)
    assert isinstance(labels_single, Labels)
    save_jabs(labels_single, 2, tmp_path)
+    loaded_labels = load_jabs(tmp_path / "actual_filename_based_on_video.jabs")
+    assert labels_single.essential_data == loaded_labels.essential_data

    labels_multi = load_jabs(jabs_real_data_v5)
    assert isinstance(labels_multi, Labels)
    save_jabs(labels_multi, 3, tmp_path)
+    loaded_labels = load_jabs(tmp_path / "actual_filename_based_on_video_v3.jabs")
+    assert labels_multi.essential_data == loaded_labels.essential_data

    save_jabs(labels_multi, 4, tmp_path)
+    loaded_labels = load_jabs(tmp_path / "actual_filename_based_on_video_v4.jabs")
+    assert labels_multi.essential_data == loaded_labels.essential_data

    save_jabs(labels_multi, 5, tmp_path)
+    loaded_labels = load_jabs(tmp_path / "actual_filename_based_on_video_v5.jabs")
+    assert labels_multi.essential_data == loaded_labels.essential_data

Please replace actual_filename_based_on_video with the correct filename based on the video in the annotation. The essential_data attribute should be replaced with the actual attribute or method that checks for the presence of essential data in the Labels object.

Contributor

As an alternative, can we just minimally test for the len(labels) being equivalent before and after the roundtrip?

Full equality testing will be a bit onerous currently.

Contributor

Looks like humans are talking. I'll be quiet for now. If you need me again, just tag @coderabbitai in a new comment.

Comment on lines 84 to 163
    frames: List[LabeledFrame] = []
    # Video name is the pose file minus the suffix
    video_name = re.sub(r"(_pose_est_v[2-6])?\.h5", ".avi", labels_path)
    if not skeleton:
        skeleton = JABS_DEFAULT_SKELETON
    tracks = {}

    with h5py.File(labels_path, "r") as pose_file:
        num_frames = pose_file["poseest/points"].shape[0]
        try:
            pose_version = pose_file["poseest"].attrs["version"][0]
        except Exception:
            pose_version = 2
            tracks[1] = Track("1")
            data_shape = pose_file["poseest/points"].shape
            assert (
                len(data_shape) == 3
            ), f"Pose version not present and shape does not match single mouse: shape of {data_shape} for {labels_path}"
        # Change field name for newer pose formats
        if pose_version == 3:
            id_key = "instance_track_id"
        elif pose_version > 3:
            id_key = "instance_embed_id"
            max_ids = pose_file["poseest/points"].shape[1]

        for frame_idx in range(num_frames):
            instances = []
            pose_data = pose_file["poseest/points"][frame_idx, ...]
            # JABS stores y,x for poses
            pose_data = np.flip(pose_data, axis=-1)
            pose_conf = pose_file["poseest/confidence"][frame_idx, ...]
            # single animal case
            if pose_version == 2:
                new_instance = prediction_to_instance(
                    pose_data, pose_conf, skeleton, tracks[1]
                )
                instances.append(new_instance)
            # multi-animal case
            if pose_version > 2:
                pose_ids = pose_file["poseest/" + id_key][frame_idx, ...]
                # pose_v3 uses another field to describe the number of valid poses
                if pose_version == 3:
                    max_ids = pose_file["poseest/instance_count"][frame_idx]
                for cur_id in range(max_ids):
                    # v4+ uses reserved values for invalid/unused poses
                    # Note: ignores 'poseest/id_mask' to keep predictions that were not assigned an id
                    if pose_version > 3 and pose_ids[cur_id] <= 0:
                        continue
                    if pose_ids[cur_id] not in tracks.keys():
                        tracks[pose_ids[cur_id]] = Track(str(pose_ids[cur_id]))
                    new_instance = prediction_to_instance(
                        pose_data[cur_id],
                        pose_conf[cur_id],
                        skeleton,
                        tracks[pose_ids[cur_id]],
                    )
                    if new_instance:
                        instances.append(new_instance)
            # Static objects
            if (
                frame_idx == 0
                and pose_version >= 5
                and "static_objects" in pose_file.keys()
            ):
                present_objects = pose_file["static_objects"].keys()
                for cur_object in present_objects:
                    object_keypoints = pose_file["static_objects/" + cur_object][:]
                    object_skeleton = make_simple_skeleton(
                        cur_object, object_keypoints.shape[0]
                    )
                    new_instance = prediction_to_instance(
                        object_keypoints,
                        np.ones(object_keypoints.shape[:-1]),
                        object_skeleton,
                    )
                    if new_instance:
                        instances.append(new_instance)
            frame_label = LabeledFrame(Video(video_name), frame_idx, instances)
            frames.append(frame_label)
    return Labels(frames)
Contributor

@coderabbitai coderabbitai bot Sep 18, 2023

This function is quite long and does multiple things: it reads data from a file, processes it, and creates various objects. Consider breaking it down into smaller functions for better readability and maintainability. For example, you could have separate functions for reading the file, processing the data, and creating the objects.

Contributor Author

@coderabbitai: ignore
"creating the objects" is already separated as prediction_to_instance.
Reading and processing the data is coupled to reduce RAM footprint (only have 2 copies in the frame loop instead of video loop).

The checks and keys for reading could move to another function, but I don't have the time to make that refactor.

Contributor

@SkepticRaven: I understand your concerns about memory usage and time constraints. However, it's important to note that maintainability and readability of the code are also crucial aspects of software development.

While the prediction_to_instance function does indeed separate the creation of objects, the read_labels function still performs multiple tasks which could be further broken down. For instance, handling different pose versions and processing static objects could potentially be moved into their own functions.

Here's a quick example of how you might refactor the pose version handling:

def handle_pose_version(pose_file, num_frames):
    try:
        pose_version = pose_file["poseest"].attrs["version"][0]
    except Exception:
        pose_version = 2
        tracks[1] = Track("1")
        data_shape = pose_file["poseest/points"].shape
        assert (
            len(data_shape) == 3
        ), f"Pose version not present and shape does not match single mouse: shape of {data_shape} for {labels_path}"
    return pose_version

# In read_labels function
pose_version = handle_pose_version(pose_file, num_frames)

This is just a suggestion and may not perfectly fit your needs, but hopefully it gives you an idea of how you might approach this. Even small steps towards cleaner, more modular code can have significant benefits in the long run.

Comment on lines 96 to 101
pose_version = 2
tracks[1] = Track("1")
data_shape = pose_file["poseest/points"].shape
assert (
    len(data_shape) == 3
), f"Pose version not present and shape does not match single mouse: shape of {data_shape} for {labels_path}"
Contributor

There's no specific error handling here. If an exception occurs when trying to read the pose version, the code assumes that the pose version is 2. However, the exception could be due to other issues (e.g., file not found, permission denied). It would be better to catch specific exceptions and handle them appropriately.
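
One way to narrow the fallback is to catch only the exceptions that mean "no version recorded". The sketch below uses a plain mapping in place of the HDF5 attribute object, which is an assumption made so the example runs without any pose file; `read_pose_version` is a hypothetical helper name.

```python
from typing import Any, Mapping


def read_pose_version(attrs: Mapping[str, Any], default: int = 2) -> int:
    """Return the major pose version, falling back only when it is absent or empty."""
    try:
        return int(attrs["version"][0])
    except (KeyError, IndexError):
        # Missing or empty "version" attribute: treat the file as the
        # unversioned single-mouse format. Other errors still propagate.
        return default


print(read_pose_version({"version": [5, 0]}))  # 5
print(read_pose_version({}))                   # 2
```

With this shape, problems such as an unreadable file surface as their own exceptions instead of being silently interpreted as version 2.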

Comment on lines 203 to 208
if confidence[i] > 0.001:
    points[cur_node] = Point(
        data[i, 0],
        data[i, 1],
        visible=True,
    )
Contributor

The confidence threshold (0.001) is hardcoded. Consider making this a parameter of the function so that it can be easily adjusted if needed.
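
A sketch of what that parameterization could look like. `visible_points` and `min_confidence` are hypothetical names, and plain tuples stand in for the `Point` objects in the snippet; the default of 0.001 is taken from the quoted code.

```python
from typing import Dict, Sequence, Tuple


def visible_points(
    node_names: Sequence[str],
    xy: Sequence[Tuple[float, float]],
    confidence: Sequence[float],
    min_confidence: float = 0.001,
) -> Dict[str, Tuple[float, float]]:
    """Keep only keypoints whose confidence exceeds min_confidence."""
    return {
        name: point
        for name, point, conf in zip(node_names, xy, confidence)
        if conf > min_confidence
    }


print(visible_points(["nose", "tail"], [(10.0, 20.0), (30.0, 40.0)], [0.9, 0.0]))
# {'nose': (10.0, 20.0)}
```

Callers that need a stricter cutoff can pass `min_confidence` explicitly while existing callers keep the current behavior.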

Comment on lines 330 to 331
if os.path.exists(out_filename):
    pass
Contributor

The code checks if the output file already exists but doesn't do anything with this information. If overwriting is not desired, consider raising an error or warning if the file already exists.
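
A minimal sketch of one possible policy: warn and skip by default, and write only when the caller opts in. `should_write` and its `overwrite` flag are hypothetical and not part of the PR.

```python
import os
import tempfile
import warnings


def should_write(out_filename: str, overwrite: bool = False) -> bool:
    """Decide whether to write out_filename, warning when an existing file is kept."""
    if os.path.exists(out_filename) and not overwrite:
        warnings.warn(f"Skipping {out_filename} because it already exists.")
        return False
    return True


with tempfile.TemporaryDirectory() as tmp_dir:
    target = os.path.join(tmp_dir, "video_pose_est_v2.h5")
    print(should_write(target))                  # True: nothing there yet
    open(target, "w").close()                    # simulate an earlier export
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")          # keep the demo output quiet
        print(should_write(target))              # False: existing file is kept
    print(should_write(target, overwrite=True))  # True: caller opted in
```

Raising `FileExistsError` instead of warning would be the stricter alternative if silent skips are undesirable.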

Contributor

@talmo talmo left a comment

Overall LGTM! I'll leave it open for another day or until I hear back in case you want to make more changes, then we'll merge.

Thanks!!

Contributor

@coderabbitai coderabbitai bot left a comment

Review Status

Actionable comments generated: 6

Commits Files that changed from the base of the PR and between 70042d3 and 31ec9e3.
Files selected for processing (2)
  • sleap_io/io/jabs.py (1 hunks)
  • tests/io/test_main.py (2 hunks)
Additional comments (Suppressed): 1
tests/io/test_main.py (1)
  • 6-11: The import statements have been updated to include the new load_jabs and save_jabs functions. Ensure that these functions are correctly implemented in the imported module.

Comment on lines 91 to 94
    if not os.access(labels_path, os.F_OK):
        raise PermissionError(f"{labels_path} cannot be accessed.")
    if not os.access(labels_path, os.R_OK):
        raise FileNotFoundError(f"{labels_path} doesn't exist.")
Contributor

The order of the checks for file existence and accessibility is incorrect. The check for file existence should be performed before checking for file access permissions. If the file does not exist, there's no point in checking its permissions.

-    if not os.access(labels_path, os.F_OK):
-        raise PermissionError(f"{labels_path} cannot be accessed.")
-    if not os.access(labels_path, os.R_OK):
-        raise FileNotFoundError(f"{labels_path} doesn't exist.")
+    if not os.path.exists(labels_path):
+        raise FileNotFoundError(f"{labels_path} doesn't exist.")
+    if not os.access(labels_path, os.R_OK):
+        raise PermissionError(f"{labels_path} cannot be accessed.")

Comment on lines +202 to +204
    assert (
        len(skeleton.nodes) == data.shape[0]
    ), f"Skeleton ({len(skeleton.nodes)}) does not match number of keypoints ({data.shape[0]})"
Contributor

The assertion error message seems to be swapped. It should indicate that the number of keypoints does not match the skeleton nodes.

-    assert (
-        len(skeleton.nodes) == data.shape[0]
-    ), f"Skeleton ({len(skeleton.nodes)}) does not match number of keypoints ({data.shape[0]})"
+    assert (
+        len(skeleton.nodes) == data.shape[0]
+    ), f"Number of keypoints ({data.shape[0]}) does not match skeleton nodes ({len(skeleton.nodes)})"

Comment on lines +284 to +286
warnings.warn(
    f"JABS format only supports 12 keypoints for mice. Skipping storage of instance on frame {label.frame_idx} with {len(instance.points)} keypoints."
)
Contributor

This warning message could potentially flood the console if there are many instances with a different number of keypoints than expected. Consider aggregating these warnings and printing a summary instead of individual messages for each instance.
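
A sketch of the aggregated variant: count the skipped instances while converting and emit a single summary warning at the end. `warn_skipped_instances` and its `(frame_idx, n_keypoints)` input are hypothetical simplifications of the conversion loop.

```python
import warnings
from collections import Counter
from typing import Iterable, Tuple


def warn_skipped_instances(
    keypoint_counts: Iterable[Tuple[int, int]], expected: int = 12
) -> int:
    """Emit one summary warning for all instances whose keypoint count is unexpected.

    keypoint_counts yields (frame_idx, n_keypoints) pairs. Returns the number
    of instances that would be skipped.
    """
    skipped = Counter(n for _, n in keypoint_counts if n != expected)
    total = sum(skipped.values())
    if total:
        detail = ", ".join(
            f"{count} with {n} keypoints" for n, count in sorted(skipped.items())
        )
        warnings.warn(
            f"JABS format only supports {expected} keypoints for mice. "
            f"Skipped {total} instance(s): {detail}."
        )
    return total


total = warn_skipped_instances([(0, 12), (1, 3), (2, 3), (3, 4)])
print(total)  # 3
```

The trade-off is that the summary no longer names the individual frames, so per-frame detail would have to be logged elsewhere if it is needed.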

Comment on lines +327 to +347
    for video in labels.videos:
        converted_labels = convert_labels(labels, video)
        out_filename = (
            os.path.splitext(video.filename)[0] + f"_pose_est_v{pose_version}.h5"
        )
        if root_folder:
            out_filename = os.path.join(root_folder, out_filename)
        os.makedirs(os.path.dirname(out_filename), exist_ok=True)
        if os.path.exists(out_filename):
            warnings.warn(f"Skipping {out_filename} because it already exists.")
            continue
        if pose_version == 2:
            write_jabs_v2(converted_labels, out_filename)
        elif pose_version == 3:
            write_jabs_v3(converted_labels, out_filename)
        elif pose_version == 4:
            write_jabs_v4(converted_labels, out_filename)
        elif pose_version == 5:
            write_jabs_v5(converted_labels, out_filename)
        else:
            raise NotImplementedError(f"Pose format {pose_version} not supported.")
Contributor

The write_labels function currently supports pose versions 2 to 5. However, it does not handle cases where an unsupported version is passed. Adding a default case to handle unsupported versions would improve error handling.

+    else:
+        raise ValueError(f"Unsupported pose version: {pose_version}")
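
Note that the quoted `if`/`elif` chain already ends in an `else` that raises `NotImplementedError`. An alternative layout that makes the supported versions explicit is a dispatch table. The sketch below uses stand-in writers that only record their calls; `WRITERS`, `write_pose`, and `_make_writer` are hypothetical names, not the PR's code.

```python
from typing import Callable, Dict, List, Tuple

calls: List[Tuple[int, str]] = []


def _make_writer(version: int) -> Callable[[dict, str], None]:
    # Stand-in for write_jabs_v2 ... write_jabs_v5: records the call
    # instead of writing a file.
    def writer(data: dict, filename: str) -> None:
        calls.append((version, filename))

    return writer


WRITERS: Dict[int, Callable[[dict, str], None]] = {
    version: _make_writer(version) for version in (2, 3, 4, 5)
}


def write_pose(data: dict, filename: str, pose_version: int) -> None:
    """Look up the writer for pose_version and fail loudly on unknown versions."""
    try:
        writer = WRITERS[pose_version]
    except KeyError:
        raise NotImplementedError(
            f"Pose format {pose_version} not supported."
        ) from None
    writer(data, filename)


write_pose({}, "video_pose_est_v4.h5", 4)
print(calls)  # [(4, 'video_pose_est_v4.h5')]
```

Adding a future version then means registering one more entry rather than extending the branch chain.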

Comment on lines +516 to +538
def write_jabs_v5(data: dict, filename: str):
    """Write JABS pose file v5 data to file.

    Writes multi-mouse pose, longterm identity, and static object data.

    Args:
        data: Dictionary of JABS data generated from convert_labels
        filename: Filename to write data to
    """
    # v5 extends v4
    write_jabs_v4(data, filename)
    with h5py.File(filename, "a") as h5:
        pose_grp = h5.require_group("poseest")
        pose_grp.attrs.update({"version": [5, 0]})
        if "static_objects" in data.keys():
            object_grp = h5.require_group("static_objects")
            for object_key, object_keypoints in data["static_objects"].items():
                object_grp.require_dataset(
                    object_key,
                    object_keypoints.shape,
                    np.uint16,
                    data=object_keypoints.astype(np.uint16),
                )
Contributor

The write_jabs_v2, write_jabs_v3, write_jabs_v4, and write_jabs_v5 functions are very similar and contain duplicated code. Consider refactoring these functions to reduce code duplication and improve maintainability.

Comment on lines +43 to +59
def test_jabs(tmp_path, jabs_real_data_v2, jabs_real_data_v5):
    labels_single = load_jabs(jabs_real_data_v2)
    assert isinstance(labels_single, Labels)
    save_jabs(labels_single, 2, tmp_path)
    labels_single_written = load_jabs(str(tmp_path / jabs_real_data_v2))
    # Confidence field is not preserved, so just check number of labels
    assert len(labels_single) == len(labels_single_written)

    labels_multi = load_jabs(jabs_real_data_v5)
    assert isinstance(labels_multi, Labels)
    save_jabs(labels_multi, 3, tmp_path)
    save_jabs(labels_multi, 4, tmp_path)
    save_jabs(labels_multi, 5, tmp_path)
    labels_v5_written = load_jabs(str(tmp_path / jabs_real_data_v5))
    # v5 contains all v4 and v3 data, so only need to check v5
    # Confidence field and ordering of identities is not preserved, so just check number of labels
    assert len(labels_v5_written) == len(labels_multi)
Contributor

A new test function, test_jabs, has been added to test the functionality of the load_jabs and save_jabs functions. The tests seem to be well-structured and cover different scenarios (single and multi-mouse poses). However, there is a potential issue with the way files are being saved and loaded for testing.

In lines 46 and 53-55, the save_jabs function is called to save labels to a file at tmp_path. However, the filename argument passed to save_jabs is an integer (2, 3, 4, or 5), not a string representing a filename. This could lead to unexpected behavior or errors when trying to save the file.

In lines 47 and 56, the load_jabs function is called to load labels from a file at str(tmp_path / jabs_real_data_v2) or str(tmp_path / jabs_real_data_v5). However, these paths may not correspond to the files saved by save_jabs if the filenames are not correctly specified.

Please verify that the save_jabs function can accept an integer as a filename, and that the paths used in load_jabs correspond to the files saved by save_jabs.

-    save_jabs(labels_single, 2, tmp_path)
+    save_jabs(labels_single, "file2", tmp_path)

-    save_jabs(labels_multi, 3, tmp_path)
+    save_jabs(labels_multi, "file3", tmp_path)

-    save_jabs(labels_multi, 4, tmp_path)
+    save_jabs(labels_multi, "file4", tmp_path)

-    save_jabs(labels_multi, 5, tmp_path)
+    save_jabs(labels_multi, "file5", tmp_path)

-    labels_single_written = load_jabs(str(tmp_path / jabs_real_data_v2))
+    labels_single_written = load_jabs(str(tmp_path / "file2"))

-    labels_v5_written = load_jabs(str(tmp_path / jabs_real_data_v5))
+    labels_v5_written = load_jabs(str(tmp_path / "file5"))

Contributor

@coderabbitai coderabbitai bot left a comment

Review Status

Actionable comments generated: 3

Commits Files that changed from the base of the PR and between 31ec9e3 and 3fea3bd.
Files selected for processing (1)
  • sleap_io/io/jabs.py (1 hunks)
Additional comments (Suppressed): 10
sleap_io/io/jabs.py (10)
  • 1-66: The module starts by defining the default keypoints, edges, symmetries, and skeleton for JABS files. This is a good practice as it provides a clear understanding of the expected structure of the data and allows for easy modification if the structure changes in the future.

  • 91-94: Good error handling for file access issues. The function checks if the file exists and if it can be accessed before attempting to read it.

  • 172-184: The make_simple_skeleton function creates a Skeleton with a requested number of nodes attached in a line. This could be useful for creating custom skeletons for different types of data.

  • 187-220: The prediction_to_instance function converts prediction data into an Instance. It includes a check to ensure that the number of keypoints in the skeleton matches the number of keypoints in the data, which is a good practice for ensuring data consistency.

  • 222-237: The get_max_ids_in_video function determines the maximum number of identities that exist at the same time in a video. This could be useful for allocating resources or setting parameters based on the complexity of the video.

  • 350-386: The tracklets_to_v3 function changes identity tracklets to the v3 format specifications. It includes a check to ensure that the input tracklet matrix is two-dimensional, which is a good practice for ensuring data consistency.

  • 389-415: The write_jabs_v2 function writes JABS pose file v2 data to file. It checks that the data being written is single mouse data, which is a good practice for ensuring data consistency.

  • 417-466: The write_jabs_v3 function writes JABS pose file v3 data to file. It uses the tracklets_to_v3 function to convert the identity tracklets to the v3 format before writing them to file.

  • 467-514: The write_jabs_v4 function writes JABS pose file v4 data to file. It extends the write_jabs_v3 function and adds new fields for v4.

  • 516-538: The write_jabs_v5 function writes JABS pose file v5 data to file. It extends the write_jabs_v4 function and adds support for static objects.

Comment on lines +68 to +170
    """Read JABS style pose from a file and return a `Labels` object.

    TODO: Attributes are ignored, including px_to_cm field.
    TODO: Segmentation data ignored in v6, but will read in pose.
    TODO: Lixit static objects currently stored as n_lixit,2 (eg 1 object). Should be converted to multiple objects

    Args:
        labels_path: Path to the JABS pose file.
        skeleton: An optional `Skeleton` object. Defaults to JABS pose version 2-6.

    Returns:
        Parsed labels as a `Labels` instance.
    """
    frames: List[LabeledFrame] = []
    # Video name is the pose file minus the suffix
    video_name = re.sub(r"(_pose_est_v[2-6])?\.h5", ".avi", labels_path)
    if not skeleton:
        skeleton = JABS_DEFAULT_SKELETON
    tracks = {}

    if not os.access(labels_path, os.F_OK):
        raise FileNotFoundError(f"{labels_path} doesn't exist.")
    if not os.access(labels_path, os.R_OK):
        raise PermissionError(f"{labels_path} cannot be accessed.")

    with h5py.File(labels_path, "r") as pose_file:
        num_frames = pose_file["poseest/points"].shape[0]
        try:
            pose_version = pose_file["poseest"].attrs["version"][0]
        except (KeyError, IndexError):
            pose_version = 2
            data_shape = pose_file["poseest/points"].shape
            assert (
                len(data_shape) == 3
            ), f"Pose version not present and shape does not match single mouse: shape of {data_shape} for {labels_path}"
        if pose_version == 2:
            tracks[1] = Track("1")
        # Change field name for newer pose formats
        if pose_version == 3:
            id_key = "instance_track_id"
        elif pose_version > 3:
            id_key = "instance_embed_id"
            max_ids = pose_file["poseest/points"].shape[1]

        for frame_idx in range(num_frames):
            instances = []
            pose_data = pose_file["poseest/points"][frame_idx, ...]
            # JABS stores y,x for poses
            pose_data = np.flip(pose_data, axis=-1)
            pose_conf = pose_file["poseest/confidence"][frame_idx, ...]
            # single animal case
            if pose_version == 2:
                new_instance = prediction_to_instance(
                    pose_data, pose_conf, skeleton, tracks[1]
                )
                instances.append(new_instance)
            # multi-animal case
            if pose_version > 2:
                pose_ids = pose_file["poseest/" + id_key][frame_idx, ...]
                # pose_v3 uses another field to describe the number of valid poses
                if pose_version == 3:
                    max_ids = pose_file["poseest/instance_count"][frame_idx]
                for cur_id in range(max_ids):
                    # v4+ uses reserved values for invalid/unused poses
                    # Note: ignores 'poseest/id_mask' to keep predictions that were not assigned an id
                    if pose_version > 3 and pose_ids[cur_id] <= 0:
                        continue
                    if pose_ids[cur_id] not in tracks.keys():
                        tracks[pose_ids[cur_id]] = Track(str(pose_ids[cur_id]))
                    new_instance = prediction_to_instance(
                        pose_data[cur_id],
                        pose_conf[cur_id],
                        skeleton,
                        tracks[pose_ids[cur_id]],
                    )
                    if new_instance:
                        instances.append(new_instance)
            # Static objects
            if (
                frame_idx == 0
                and pose_version >= 5
                and "static_objects" in pose_file.keys()
            ):
                present_objects = pose_file["static_objects"].keys()
                for cur_object in present_objects:
                    object_keypoints = pose_file["static_objects/" + cur_object][:]
                    object_skeleton = make_simple_skeleton(
                        cur_object, object_keypoints.shape[0]
                    )
                    new_instance = prediction_to_instance(
                        object_keypoints,
                        np.ones(object_keypoints.shape[:-1]),
                        object_skeleton,
                    )
                    if new_instance:
                        instances.append(new_instance)
            frame_label = LabeledFrame(Video(video_name), frame_idx, instances)
            frames.append(frame_label)
    return Labels(frames)

The read_labels function reads JABS style pose from a file and returns a Labels object. It handles both single animal and multi-animal cases, as well as static objects. However, there are several TODO comments indicating areas for future improvement, such as handling attributes, segmentation data, and lixit static objects. These should be addressed to ensure the function fully supports all aspects of the JABS format.
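The two per-frame conventions the quoted reader relies on (coordinates stored as y,x, and v4+ identity values at or below zero marking unused slots) can be sketched in isolation with NumPy. This is a minimal illustration, not code from the PR: `valid_poses_xy` and the sample arrays are hypothetical, and no HDF5 file is read.

```python
import numpy as np


def valid_poses_xy(points_yx: np.ndarray, ids: np.ndarray) -> dict:
    """Return {identity: (n_keypoints, 2) array in x,y order} for one frame.

    Hypothetical helper for illustration; not part of the PR.
    points_yx: (n_animals, n_keypoints, 2) array stored as (y, x).
    ids: (n_animals,) array; values <= 0 mark invalid or unused slots.
    """
    # Reverse the last axis so each (y, x) pair becomes (x, y).
    points_xy = np.flip(points_yx, axis=-1)
    poses = {}
    for slot, identity in enumerate(ids):
        # Skip slots that carry a reserved (non-positive) identity.
        if identity <= 0:
            continue
        poses[int(identity)] = points_xy[slot]
    return poses


# Two animal slots with two keypoints each; the first slot is unused.
frame_points = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
frame_ids = np.array([0, 3])
poses = valid_poses_xy(frame_points, frame_ids)
```

Here only identity 3 survives, and its keypoints come back in x,y order.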

Comment on lines +240 to +314
def convert_labels(all_labels: Labels, video: Video) -> dict:
    """Convert a `Labels` object into JABS-formatted annotations.

    Args:
        all_labels: SLEAP `Labels` to be converted to JABS format.
        video: name of video to be converted

    Returns:
        Dictionary of JABS data of the `Labels` data.
    """
    labels = all_labels.find(video=video)

    # Determine shape of output
    # Low estimate of last frame labeled
    num_frames = max([x.frame_idx for x in labels]) + 1
    # If there is metadata available for the video, use that
    if video.shape:
        num_frames = max(num_frames, video.shape[0])
    num_keypoints = [len(x.nodes) for x in all_labels.skeletons if x.name == "Mouse"][0]
    num_mice = get_max_ids_in_video(labels, key="Mouse")
    # Note that this 1-indexes identities
    track_2_idx = {
        key: val + 1
        for key, val in zip(all_labels.tracks, range(len(all_labels.tracks)))
    }
    last_unassigned_id = num_mice

    keypoint_mat = np.zeros([num_frames, num_mice, num_keypoints, 2], dtype=np.uint16)
    confidence_mat = np.zeros([num_frames, num_mice, num_keypoints], dtype=np.float32)
    identity_mat = np.zeros([num_frames, num_mice], dtype=np.uint32)
    instance_vector = np.zeros([num_frames], dtype=np.uint8)
    static_objects = {}

    # Populate the matrices with data
    for label in labels:
        assigned_instances = 0
        for instance_idx, instance in enumerate(label.instances):
            # Static objects just get added to the object dict
            # This will clobber data if more than one frame is annotated
            if instance.skeleton.name != "Mouse":
                static_objects[instance.skeleton.name] = instance.numpy()
                continue
            pose = instance.numpy()
            if pose.shape[0] != len(JABS_DEFAULT_KEYPOINTS):
                warnings.warn(
                    f"JABS format only supports 12 keypoints for mice. Skipping storage of instance on frame {label.frame_idx} with {len(instance.points)} keypoints."
                )
                continue
            missing_points = np.isnan(pose[:, 0])
            pose[np.isnan(pose)] = 0
            # JABS stores y,x for poses
            pose = np.flip(pose.astype(np.uint16), axis=-1)
            keypoint_mat[label.frame_idx, instance_idx, :, :] = pose
            confidence_mat[label.frame_idx, instance_idx, ~missing_points] = 1.0
            if instance.track:
                identity_mat[label.frame_idx, instance_idx] = track_2_idx[
                    instance.track
                ]
            else:
                warnings.warn(
                    f"Pose with unassigned track found on {label.video.filename} frame {label.frame_idx} instance {instance_idx}. Assigning ID {last_unassigned_id}."
                )
                identity_mat[label.frame_idx, instance_idx] = last_unassigned_id
                last_unassigned_id += 1
            assigned_instances += 1
        instance_vector[label.frame_idx] = assigned_instances

    # Return the data as a dict
    return {
        "keypoints": keypoint_mat.astype(np.uint16),
        "confidence": confidence_mat.astype(np.float32),
        "identity": identity_mat.astype(np.uint32),
        "num_identities": instance_vector.astype(np.uint16),
        "static_objects": static_objects,
    }

The convert_labels function converts a Labels object into JABS-formatted annotations. It handles both single-mouse and multi-mouse cases, as well as static objects. However, it warns and skips any mouse instance whose keypoint count differs from the 12 that the JABS format supports (the check is `!=`, not only "more than"). An input Labels object with a different mouse skeleton therefore loses those poses in the output.
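The per-instance encoding step discussed here can be sketched on its own: reject poses with the wrong keypoint count, turn NaN keypoints into zeros with zero confidence, and store coordinates as y,x in `uint16`. This is an illustrative sketch under stated assumptions; `encode_pose` and `N_KEYPOINTS` are hypothetical names, and the real function writes into preallocated matrices rather than returning arrays.

```python
import numpy as np

N_KEYPOINTS = 12  # keypoint count the quoted code requires for mice


def encode_pose(pose_xy: np.ndarray, n_keypoints: int = N_KEYPOINTS):
    """Return (pose_yx as uint16, confidence as float32), or None if skipped.

    Hypothetical helper for illustration; not part of the PR.
    pose_xy: (n_keypoints, 2) float array in (x, y) order, NaN when missing.
    """
    # Any count other than the expected one is skipped, not only larger ones.
    if pose_xy.shape[0] != n_keypoints:
        return None
    pose = pose_xy.astype(np.float64)  # astype copies, so the input is untouched
    missing = np.isnan(pose[:, 0])
    # Zero out NaNs before the integer cast, which is undefined for NaN.
    pose[np.isnan(pose)] = 0
    # Cast truncates toward zero; flipping the last axis gives (y, x).
    pose_yx = np.flip(pose.astype(np.uint16), axis=-1)
    confidence = np.zeros(n_keypoints, dtype=np.float32)
    confidence[~missing] = 1.0
    return pose_yx, confidence


# A two-keypoint toy pose (n_keypoints overridden for brevity).
toy_pose = np.array([[1.7, 2.2], [np.nan, np.nan]])
encoded, conf = encode_pose(toy_pose, n_keypoints=2)
skipped = encode_pose(np.zeros((13, 2)))
```

Note that the `uint16` cast drops sub-pixel precision, which is a property of the quoted code as well as of this sketch.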

Comment on lines +317 to +348
def write_labels(labels: Labels, pose_version: int, root_folder: str):
    """Convert and save a SLEAP `Labels` object to a JABS pose file.

    Only supports pose version 2 (single mouse) and 3-5 (multi mouse).

    Args:
        labels: SLEAP `Labels` to be converted to JABS pose format.
        pose_version: JABS pose version to use when writing data.
        root_folder: Root folder where the jabs files should be written
    """
    for video in labels.videos:
        converted_labels = convert_labels(labels, video)
        out_filename = (
            os.path.splitext(video.filename)[0] + f"_pose_est_v{pose_version}.h5"
        )
        if root_folder:
            out_filename = os.path.join(root_folder, out_filename)
        os.makedirs(os.path.dirname(out_filename), exist_ok=True)
        if os.path.exists(out_filename):
            warnings.warn(f"Skipping {out_filename} because it already exists.")
            continue
        if pose_version == 2:
            write_jabs_v2(converted_labels, out_filename)
        elif pose_version == 3:
            write_jabs_v3(converted_labels, out_filename)
        elif pose_version == 4:
            write_jabs_v4(converted_labels, out_filename)
        elif pose_version == 5:
            write_jabs_v5(converted_labels, out_filename)
        else:
            raise NotImplementedError(f"Pose format {pose_version} not supported.")

The write_labels function converts and saves a SLEAP Labels object to a JABS pose file. It supports multiple versions of the JABS pose format, which is a good practice for ensuring compatibility with different versions of the format. However, it raises a NotImplementedError if an unsupported pose version is provided. It would be better to provide a more informative error message indicating which versions are supported.

-            raise NotImplementedError(f"Pose format {pose_version} not supported.")
+            raise ValueError(f"Unsupported pose format {pose_version}. Supported formats are 2, 3, 4, and 5.")
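One way to keep such an error message in sync with the versions actually supported is to dispatch through a table and build the message from its keys. The sketch below is only an illustration of that idea: `WRITERS`, `pick_writer`, `pose_filename`, and the stub writers are hypothetical and stand in for the PR's `write_jabs_v2` through `write_jabs_v5`, whose bodies are not shown here.

```python
import os


def _stub_writer(version: int):
    """Build a placeholder writer; the real writers save HDF5 files."""
    def write(data: dict, filename: str) -> str:
        return f"v{version}:{filename}"
    return write


# Hypothetical dispatch table keyed by JABS pose version.
WRITERS = {version: _stub_writer(version) for version in (2, 3, 4, 5)}


def pick_writer(pose_version: int):
    """Return the writer for a version or raise a descriptive ValueError."""
    try:
        return WRITERS[pose_version]
    except KeyError:
        supported = ", ".join(str(v) for v in sorted(WRITERS))
        raise ValueError(
            f"Unsupported pose format {pose_version}. "
            f"Supported formats are {supported}."
        ) from None


def pose_filename(video_filename: str, pose_version: int) -> str:
    """Mirror the quoted naming rule: video stem plus a versioned suffix."""
    stem = os.path.splitext(video_filename)[0]
    return stem + f"_pose_est_v{pose_version}.h5"


example_name = pose_filename("clips/mouse.avi", 5)
example_result = pick_writer(3)({}, "a.h5")
```

With this layout, adding a new writer to the table updates the error text without touching the `raise` statement.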

@SkepticRaven
Contributor Author

@talmo I think that should be it!
Also amusing that the bot updates its poem every review.

@talmo talmo merged commit 9a04f7f into talmolab:main Sep 20, 2023
9 checks passed