diff --git a/lightning_pose/utils/scripts.py b/lightning_pose/utils/scripts.py index fecefea6..f74f9c2f 100644 --- a/lightning_pose/utils/scripts.py +++ b/lightning_pose/utils/scripts.py @@ -236,13 +236,19 @@ def get_loss_factories( # assume user has provided a set of columns that are present in each view num_keypoints = cfg.data.num_keypoints num_views = len(cfg.data.view_names) - loss_params_dict["unsupervised"][loss_name][ - "mirrored_column_matches" - ] = [ - (v * num_keypoints - + np.array(cfg.data.mirrored_column_matches, dtype=int)).tolist() - for v in range(num_views) - ] + if isinstance(cfg.data.mirrored_column_matches[0], int): + loss_params_dict["unsupervised"][loss_name][ + "mirrored_column_matches" + ] = [ + (v * num_keypoints + + np.array(cfg.data.mirrored_column_matches, dtype=int)).tolist() + for v in range(num_views) + ] + else: + # allow user to force specific mapping in multiview case + loss_params_dict["unsupervised"][loss_name][ + "mirrored_column_matches" + ] = cfg.data.mirrored_column_matches else: # user must provide all matching columns loss_params_dict["unsupervised"][loss_name][