diff --git a/src/depth_pro/depth_pro.py b/src/depth_pro/depth_pro.py index f31b4e1..36fe5a0 100644 --- a/src/depth_pro/depth_pro.py +++ b/src/depth_pro/depth_pro.py @@ -71,9 +71,11 @@ def create_backbone_model( def create_model_and_transforms( config: DepthProConfig = DEFAULT_MONODEPTH_CONFIG_DICT, - device: torch.device = torch.device("cpu"), - precision: torch.dtype = torch.float32, + device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"), ) -> Tuple[DepthPro, Compose]: + # Determine the precision based on the device here + precision = torch.float16 if device.type == 'cuda' else torch.float32 + """Create a DepthPro model and load weights from `config.checkpoint_uri`. Args: @@ -146,7 +148,8 @@ def create_model_and_transforms( # which we would not use. We only use the encoding. missing_keys = [key for key in missing_keys if "fc_norm" not in key] if len(missing_keys) != 0: - raise KeyError(f"Keys are missing when loading monodepth: {missing_keys}") + raise KeyError(f"Keys are missing when loading monodepth: {missing_keys}. " + f"Ensure the model checkpoint is compatible with the architecture or update the state dict.") return model, transform