From 7f37d20abde557e0db828c5dbb9466b482544843 Mon Sep 17 00:00:00 2001 From: themattinthehatt Date: Wed, 31 Jul 2024 10:10:26 -0400 Subject: [PATCH] allow user-defined keypoint correspondence in true multiview datasets --- lightning_pose/utils/scripts.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/lightning_pose/utils/scripts.py b/lightning_pose/utils/scripts.py index fecefea6..f74f9c2f 100644 --- a/lightning_pose/utils/scripts.py +++ b/lightning_pose/utils/scripts.py @@ -236,13 +236,19 @@ def get_loss_factories( # assume user has provided a set of columns that are present in each view num_keypoints = cfg.data.num_keypoints num_views = len(cfg.data.view_names) - loss_params_dict["unsupervised"][loss_name][ - "mirrored_column_matches" - ] = [ - (v * num_keypoints - + np.array(cfg.data.mirrored_column_matches, dtype=int)).tolist() - for v in range(num_views) - ] + if isinstance(cfg.data.mirrored_column_matches[0], int): + loss_params_dict["unsupervised"][loss_name][ + "mirrored_column_matches" + ] = [ + (v * num_keypoints + + np.array(cfg.data.mirrored_column_matches, dtype=int)).tolist() + for v in range(num_views) + ] + else: + # allow user to force specific mapping in multiview case + loss_params_dict["unsupervised"][loss_name][ + "mirrored_column_matches" + ] = cfg.data.mirrored_column_matches else: # user must provide all matching columns loss_params_dict["unsupervised"][loss_name][