
Enhance Device and Precision Handling, and Improve Error Messages in DepthPro Model #35

Open · wants to merge 2 commits into main. The diff below shows changes from 1 commit.
src/depth_pro/depth_pro.py: 7 changes (4 additions & 3 deletions)
@@ -71,8 +71,8 @@ def create_backbone_model(

 def create_model_and_transforms(
     config: DepthProConfig = DEFAULT_MONODEPTH_CONFIG_DICT,
-    device: torch.device = torch.device("cpu"),
-    precision: torch.dtype = torch.float32,
+    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
+    precision: torch.dtype = torch.float16 if device.type == 'cuda' else torch.float32,


This use of device at this point doesn't seem to work for me; I had to set precision = torch.float16 if device.type == 'cuda' else torch.float32 in the body of the function instead.
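For context: Python evaluates default argument values when the function is defined, so the device parameter is not in scope inside the precision default. A minimal sketch of the workaround described above, using a None sentinel; this is an assumption for illustration, not the PR's exact change:

    from typing import Optional, Tuple

    def create_model_and_transforms(
        config: DepthProConfig = DEFAULT_MONODEPTH_CONFIG_DICT,
        device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        precision: Optional[torch.dtype] = None,  # sentinel; resolved below (assumed API)
    ) -> Tuple[DepthPro, Compose]:
        # Resolve the dtype inside the body, where `device` is actually bound.
        if precision is None:
            precision = torch.float16 if device.type == "cuda" else torch.float32
        ...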


@carlos-bg commented on Oct 22, 2024:


When running depth-pro-run -i PATH_TO_MY_IMG_FILE, I also hit an error:

Traceback (most recent call last):
  File "miniconda3/envs/depth-pro/bin/depth-pro-run", line 8, in <module>
    sys.exit(run_main())
  File "ml-depth-pro/src/depth_pro/cli/run.py", line 150, in main
    run(parser.parse_args())
  File "ml-depth-pro/src/depth_pro/cli/run.py", line 68, in run
    prediction = model.infer(transform(image), f_px=f_px)
  File "miniconda3/envs/depth-pro/lib/python3.9/site-packages/torchvision/transforms/transforms.py", line 95, in __call__
    img = t(img)
  File "miniconda3/envs/depth-pro/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "miniconda3/envs/depth-pro/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "miniconda3/envs/depth-pro/lib/python3.9/site-packages/torchvision/transforms/transforms.py", line 198, in forward
    return F.convert_image_dtype(image, self.dtype)
  File "miniconda3/envs/depth-pro/lib/python3.9/site-packages/torchvision/transforms/functional.py", line 243, in convert_image_dtype
    return F_t.convert_image_dtype(image, dtype)
  File "miniconda3/envs/depth-pro/lib/python3.9/site-packages/torchvision/transforms/_functional_tensor.py", line 73, in convert_image_dtype
    if torch.tensor(0, dtype=dtype).is_floating_point():
TypeError: tensor(): argument 'dtype' must be torch.dtype, not tuple
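The last frame shows torchvision's ConvertImageDtype receiving a tuple instead of a single torch.dtype. A minimal reproduction of that failure mode, with hypothetical values; the exact precision value produced in the original run is not shown:

    import torch
    from torchvision.transforms import ConvertImageDtype

    # Passing a tuple where a torch.dtype is expected raises the same TypeError at call time.
    transform = ConvertImageDtype((torch.float16, torch.float32))  # should be a single dtype
    transform(torch.zeros(3, 32, 32))
    # TypeError: tensor(): argument 'dtype' must be torch.dtype, not tuple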

 ) -> Tuple[DepthPro, Compose]:
     """Create a DepthPro model and load weights from `config.checkpoint_uri`.

@@ -146,7 +146,8 @@ def create_model_and_transforms(
     # which we would not use. We only use the encoding.
     missing_keys = [key for key in missing_keys if "fc_norm" not in key]
     if len(missing_keys) != 0:
-        raise KeyError(f"Keys are missing when loading monodepth: {missing_keys}")
+        raise KeyError(f"Keys are missing when loading monodepth: {missing_keys}. "
+                       f"Ensure the model checkpoint is compatible with the architecture or update the state dict.")

     return model, transform
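For completeness, a short usage sketch with the device and precision passed explicitly, which also sidesteps the question of how the defaults are evaluated (assuming the package exposes create_model_and_transforms at the top level, as in the repository's README):

    import torch
    from depth_pro import create_model_and_transforms

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    precision = torch.float16 if device.type == "cuda" else torch.float32

    model, transform = create_model_and_transforms(device=device, precision=precision)
    model.eval()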
