From 06dd5a54b3517dd6cb026668d454ce95c92bea5a Mon Sep 17 00:00:00 2001 From: Mayur Kawale <122032765+Mefisto04@users.noreply.github.com> Date: Sat, 12 Oct 2024 00:18:27 +0530 Subject: [PATCH 1/2] Update depth_pro.py --- src/depth_pro/depth_pro.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/depth_pro/depth_pro.py b/src/depth_pro/depth_pro.py index f31b4e1..0a2dd91 100644 --- a/src/depth_pro/depth_pro.py +++ b/src/depth_pro/depth_pro.py @@ -71,8 +71,8 @@ def create_backbone_model( def create_model_and_transforms( config: DepthProConfig = DEFAULT_MONODEPTH_CONFIG_DICT, - device: torch.device = torch.device("cpu"), - precision: torch.dtype = torch.float32, + device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"), + precision: torch.dtype = torch.float16 if device.type == 'cuda' else torch.float32, ) -> Tuple[DepthPro, Compose]: """Create a DepthPro model and load weights from `config.checkpoint_uri`. @@ -146,7 +146,8 @@ def create_model_and_transforms( # which we would not use. We only use the encoding. missing_keys = [key for key in missing_keys if "fc_norm" not in key] if len(missing_keys) != 0: - raise KeyError(f"Keys are missing when loading monodepth: {missing_keys}") + raise KeyError(f"Keys are missing when loading monodepth: {missing_keys}. " + f"Ensure the model checkpoint is compatible with the architecture or update the state dict.") return model, transform From 641e46daf10eaa0459f0632ef413ecb7ebec7694 Mon Sep 17 00:00:00 2001 From: Mayur Kawale <122032765+Mefisto04@users.noreply.github.com> Date: Fri, 25 Oct 2024 00:14:23 +0530 Subject: [PATCH 2/2] Update depth_pro.py --- src/depth_pro/depth_pro.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/depth_pro/depth_pro.py b/src/depth_pro/depth_pro.py index 0a2dd91..36fe5a0 100644 --- a/src/depth_pro/depth_pro.py +++ b/src/depth_pro/depth_pro.py @@ -72,8 +72,10 @@ def create_backbone_model( def create_model_and_transforms( config: DepthProConfig = DEFAULT_MONODEPTH_CONFIG_DICT, device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"), - precision: torch.dtype = torch.float16 if device.type == 'cuda' else torch.float32, ) -> Tuple[DepthPro, Compose]: + # Determine the precision based on the device here + precision = torch.float16 if device.type == 'cuda' else torch.float32 + """Create a DepthPro model and load weights from `config.checkpoint_uri`. Args: