Skip to content

Commit

Permalink
allow user-defined keypoint correspondence in true multiview datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
themattinthehatt committed Jul 31, 2024
1 parent d00ec3b commit 7f37d20
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions lightning_pose/utils/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,13 +236,19 @@ def get_loss_factories(
# assume user has provided a set of columns that are present in each view
num_keypoints = cfg.data.num_keypoints
num_views = len(cfg.data.view_names)
loss_params_dict["unsupervised"][loss_name][
"mirrored_column_matches"
] = [
(v * num_keypoints
+ np.array(cfg.data.mirrored_column_matches, dtype=int)).tolist()
for v in range(num_views)
]
if isinstance(cfg.data.mirrored_column_matches[0], int):
loss_params_dict["unsupervised"][loss_name][
"mirrored_column_matches"
] = [
(v * num_keypoints
+ np.array(cfg.data.mirrored_column_matches, dtype=int)).tolist()
for v in range(num_views)
]
else:
# allow user to force specific mapping in multiview case
loss_params_dict["unsupervised"][loss_name][
"mirrored_column_matches"
] = cfg.data.mirrored_column_matches
else:
# user must provide all matching columns
loss_params_dict["unsupervised"][loss_name][
Expand Down

0 comments on commit 7f37d20

Please sign in to comment.