diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..9177316
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+
+checkpoints
+out
+
diff --git a/README.md b/README.md
index 6c4ea61..cc6d313 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,5 @@
+# This fork adds a `convert_to_coreml.py` script that converts the original Depth Pro model to Core ML programs
+
 ## Depth Pro: Sharp Monocular Metric Depth in Less Than a Second
 
 This software project accompanies the research paper:
diff --git a/convert_to_coreml.py b/convert_to_coreml.py
new file mode 100644
index 0000000..9d7ecac
--- /dev/null
+++ b/convert_to_coreml.py
@@ -0,0 +1,349 @@
+import logging
+import math
+import numpy as np
+
+import coremltools as ct
+from coremltools.converters.mil import register_torch_op
+from coremltools.converters.mil.frontend.torch.ops import upsample_bilinear2d
+from coremltools.converters.mil.frontend.torch.torch_op_registry import register_torch_op
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from matplotlib import pyplot as plt
+from typing import Dict, Tuple
+
+from src.depth_pro.depth_pro import (
+    create_model_and_transforms,
+    create_backbone_model,
+    DepthProConfig
+)
+from src.depth_pro.network.decoder import MultiresConvDecoder
+from src.depth_pro.network.encoder import DepthProEncoder
+from src.depth_pro.network.fov import FOVNetwork
+from src.depth_pro.network.vit import resize_vit, resize_patch_embed
+from src.depth_pro.utils import load_rgb
+
+from torchvision.transforms import (
+    Compose,
+    ConvertImageDtype,
+    Lambda,
+    Normalize,
+    ToTensor
+)
+
+CONFIG_DICT: Dict[str, DepthProConfig] = {
+    "large_192": DepthProConfig(
+        patch_encoder_preset="dinov2l16_192",
+        image_encoder_preset="dinov2l16_192",
+        checkpoint_uri="./checkpoints/depth_pro.pt",
+        decoder_features=256,
+        use_fov_head=True,
+        fov_encoder_preset="dinov2l16_192",
+        encoder_scale_size=(192, 192),
+        head_paddings=[1, 0, 1, 0],
+        fov_head_paddings=[1, 2, 3, 0],
+    ),
+    "large_288": DepthProConfig(
+        patch_encoder_preset="dinov2l16_288",
+        image_encoder_preset="dinov2l16_288",
+        checkpoint_uri="./checkpoints/depth_pro.pt",
+        decoder_features=256,
+        use_fov_head=True,
+        fov_encoder_preset="dinov2l16_288",
+        encoder_scale_size=(288, 288),
+        head_paddings=[1, 0, 1, 0],
+        fov_head_paddings=[1, 1, 2, 0],
+    ),
+    "large_384": DepthProConfig(
+        patch_encoder_preset="dinov2l16_384",
+        image_encoder_preset="dinov2l16_384",
+        checkpoint_uri="./checkpoints/depth_pro.pt",
+        decoder_features=256,
+        use_fov_head=True,
+        fov_encoder_preset="dinov2l16_384",
+        encoder_scale_size=(384, 384),
+        head_paddings=[1, 0, 1, 0],
+        fov_head_paddings=[1, 1, 1, 0],
+    ),
+}
+
+class DepthDecoder(nn.Module):
+    def __init__(self, head: nn.Module, fov: FOVNetwork, encoder_scale_size: Tuple[int, int]):
+        super(DepthDecoder, self).__init__()
+        self.head = head
+        self.fov = fov
+        self.encoder_scale_size = encoder_scale_size
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        x = inputs[0]
+        features = inputs[1]
+        features_0 = inputs[2]
+
+        # execute fov.forward locally with a different scale_factor
+        # fov_deg = self.fov.forward(x, features_0.detach())
+        if hasattr(self.fov, "encoder"):
+            x = F.interpolate(
+                x,
+                size=self.encoder_scale_size,
+                #scale_factor=self.encoder_scale_factor,
+                mode="bilinear",
+                align_corners=False,
+            )
+            x = self.fov.encoder(x)[:, 1:].permute(0, 2, 1)
+            lowres_feature = self.fov.downsample(features_0.detach())
+            x = x.reshape_as(lowres_feature) + lowres_feature
+        else:
+            x = features_0.detach()
+
+        fov_deg = self.fov.head(x)
+        f_px = 0.5 * torch.tan(math.pi * fov_deg.to(torch.float) / 360.0)
+
+        canonical_inverse_depth = self.head(features)
+        inverse_depth = canonical_inverse_depth * f_px
+        depth = 1.0 / inverse_depth.clamp(min=1e-4, max=1e4)
+        return depth
+
+class DepthProScaled(nn.Module):
+    def __init__(self, transform: nn.Module, encoder: DepthProEncoder, decoder: MultiresConvDecoder, depth: DepthDecoder):
+        super().__init__()
+        self.transform = transform
+        self.encoder = encoder
+        self.decoder = decoder
+        self.depth = depth
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if x.shape[0] == 3:
+            x = x.unsqueeze(0)
+        image = self.transform(x)
+        encodings = self.encoder(image)
+        features, features_0 = self.decoder(encodings)
+        depth = self.depth([image, features, features_0])
+        return depth
+
+class Interpolate(nn.Module):
+    def __init__(self, size, mode):
+        super(Interpolate, self).__init__()
+        self.size = size
+        self.mode = mode
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = F.interpolate(x, size=self.size, mode=self.mode, align_corners=False)
+        return x
+
+def create_scaled_model(config: DepthProConfig) -> DepthProScaled:
+    patch_encoder, patch_encoder_config = create_backbone_model(preset = config.patch_encoder_preset)
+    image_encoder, _ = create_backbone_model(preset = config.image_encoder_preset)
+    fov_encoder, _ = create_backbone_model(preset = config.fov_encoder_preset)
+    # fov_encoder = None
+
+    dims_encoder = patch_encoder_config.encoder_feature_dims
+    hook_block_ids = patch_encoder_config.encoder_feature_layer_ids
+    encoder = DepthProEncoder(
+        dims_encoder=dims_encoder,
+        patch_encoder=patch_encoder,
+        image_encoder=image_encoder,
+        hook_block_ids=hook_block_ids,
+        decoder_features=config.decoder_features,
+    )
+
+    decoder = MultiresConvDecoder(
+        dims_encoder=[config.decoder_features] + list(encoder.dims_encoder),
+        dim_decoder=config.decoder_features,
+    )
+
+    num_features = config.decoder_features
+    fov = FOVNetwork(num_features=num_features, fov_encoder=fov_encoder)
+    # Create FOV head.
+    fov_head0 = [
+        nn.Conv2d(
+            num_features, num_features // 2, kernel_size=3, stride=2, padding=config.fov_head_paddings[0]
+        ),  # 128 x 24 x 24
+        nn.ReLU(True),
+    ]
+    fov_head = [
+        nn.Conv2d(
+            num_features // 2, num_features // 4, kernel_size=3, stride=2, padding=config.fov_head_paddings[1]
+        ),  # 64 x 12 x 12
+        nn.ReLU(True),
+        nn.Conv2d(
+            num_features // 4, num_features // 8, kernel_size=3, stride=2, padding=config.fov_head_paddings[2]
+        ),  # 32 x 6 x 6
+        nn.ReLU(True),
+        nn.Conv2d(num_features // 8, 1, kernel_size=6, stride=1, padding=config.fov_head_paddings[3]),
+    ]
+    if fov_encoder is not None:
+        fov.encoder = nn.Sequential(
+            fov_encoder, nn.Linear(fov_encoder.embed_dim, num_features // 2)
+        )
+        fov.downsample = nn.Sequential(*fov_head0)
+    else:
+        fov_head = fov_head0 + fov_head
+    fov.head = nn.Sequential(*fov_head)
+    # fov = None
+
+    last_dims = (32, 1)
+    dim_decoder = config.decoder_features
+    head = nn.Sequential(
+        nn.Conv2d(
+            dim_decoder, dim_decoder // 2, kernel_size=3, stride=1, padding=config.head_paddings[0]
+        ),
+        nn.ConvTranspose2d(
+            in_channels=dim_decoder // 2,
+            out_channels=dim_decoder // 2,
+            kernel_size=2,
+            stride=2,
+            padding=config.head_paddings[1],
+            bias=True,
+        ),
+        nn.Conv2d(
+            dim_decoder // 2,
+            last_dims[0],
+            kernel_size=3,
+            stride=1,
+            padding=config.head_paddings[2],
+        ),
+        nn.ReLU(True),
+        nn.Conv2d(last_dims[0], last_dims[1], kernel_size=1, stride=1, padding=config.head_paddings[3]),
+        nn.ReLU(),
+    )
+
+    # Set the final convolution layer's bias to be 0.
+    head[4].bias.data.fill_(0)
+
+    # from depth_pro.py
+    transform = nn.Sequential(
+        #[
+        #ToTensor(),
+        #Lambda(lambda x: x.to(device)),
+        Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
+        Interpolate(
+            size=(encoder.img_size, encoder.img_size),
+            mode="bilinear"
+        ),
+        ConvertImageDtype(torch.float32),
+        #]
+    )
+
+    depth = DepthDecoder(head, fov, config.encoder_scale_size)
+    load_state_dict(depth, config)
+
+    model = DepthProScaled(transform, encoder, decoder, depth)
+    load_state_dict(model, config)
+
+    return model
+
+def load_state_dict(model: nn.Module, config: DepthProConfig):
+    checkpoint_uri = config.checkpoint_uri
+    state_dict = torch.load(checkpoint_uri, map_location="cpu")
+    _, _ = model.load_state_dict(
+        state_dict=state_dict, strict=False
+    )
+
+def load_and_show_examples(models: tuple[DepthProScaled]):
+    plt.ion()
+    fig = plt.figure()
+    ax_rgb = fig.add_subplot(1, 1 + len(models), 1)
+
+    image, _, _ = load_rgb("data/example.jpg")
+    ax_rgb.imshow(image)
+
+    for index in range(len(models)):
+        model_run = Compose([ToTensor(), Lambda(lambda x: x.to(torch.device("cpu"))), models[index]])
+        depth_map = model_run(image).detach().cpu().numpy().squeeze()
+
+        ax_disp = fig.add_subplot(1, 1 + len(models), 2 + index)
+        ax_disp.imshow(depth_map, cmap="turbo")
+
+    fig.canvas.draw()
+    fig.canvas.flush_events()
+    plt.show(block=True)
+
+def save_coreml_packages(model: DepthProScaled):
+    transform = nn.Sequential(
+        #[
+        #ToTensor(),
+        #Lambda(lambda x: x.to(device)),
+        Normalize([127.5, 127.5, 127.5], [127.5, 127.5, 127.5]),
+        Interpolate(
+            size=(model.encoder.img_size, model.encoder.img_size),
+            mode="bilinear"
+        ),
+        ConvertImageDtype(torch.float16),
+        #]
+    )
+    save_mlpackage(transform, [[1, 3, 1080, 1920]], "DepthPro_transform", True)
+    save_mlpackage(model.encoder, [[1, 3, 768, 768]], "DepthPro_encoder")
+    save_mlpackage(model.decoder, [[1, 256, 288, 288], [1, 256, 144, 144], [1, 512, 72, 72], [1, 1024, 24, 24], [1, 1024, 24, 24]], "DepthPro_decoder")
+    save_mlpackage(model.depth, [[1, 3, 768, 768], [1, 256, 288, 288], [1, 256, 24, 24]], "DepthPro_depth")
+    save_mlpackage(model.depth.head, [[1, 256, 768, 768]], "DepthPro_head")
+
+@register_torch_op()
+def _upsample_bicubic2d_aa(context, node):
+    upsample_bilinear2d(context, node)
+
+# https://github.com/apple/coremltools/pull/2354 CoreMLTools 8.0 fix
+from coremltools.converters.mil.frontend.torch.ops import _get_inputs, _get_kwinputs
+from coremltools.converters.mil.frontend.torch.utils import TorchFrontend
+from coremltools.converters.mil.mil import Builder as mb
+from coremltools.converters.mil.mil.ops.defs._utils import promote_input_dtypes
+from coremltools.converters.mil.mil.var import Var
+@register_torch_op(torch_alias=["concat"], override=True)
+def cat(context, node):
+    def is_tensor_empty(var: Var) -> bool:
+        return np.any([size == 0 for size in var.shape])
+
+    def _parse_positional_args(context, node) -> Tuple[Var]:
+        inputs = _get_inputs(context, node, min_expected=1)
+        nargs = len(inputs)
+
+        xs = inputs[0]
+        # PyTorch can have empty tensor, which is then ignored
+        # However, CoreML does not allow such empty tensor, so remove them now
+        if np.any([is_tensor_empty(x) for x in xs]):
+            filtered_xs = [x for x in xs if not is_tensor_empty(x)]
+            xs = filtered_xs if len(filtered_xs) > 0 else [xs[0]]
+
+        dim = inputs[1] if nargs > 1 else 0
+
+        return xs, dim
+
+    def _parse_keyword_args(context, node, dim) -> Var:
+        # Only torch.export may have kwargs
+        if context.frontend != TorchFrontend.TORCHEXPORT:
+            return dim
+
+        dim = _get_kwinputs(context, node, "dim", default=[dim])[0]
+        return dim
+
+    xs, dim = _parse_positional_args(context, node)
+    dim = _parse_keyword_args(context, node, dim)
+
+    concat = mb.concat(values=promote_input_dtypes(xs), axis=dim, name=node.name)
+    context.add(concat)
+
+def save_mlpackage(G, shapes, name, image_type = False):
+    G.eval()
+    G_inputs = []
+    convert_inputs = []
+    for shape in shapes:
+        G_inputs.append(torch.randn(shape))
+        convert_inputs.append(ct.TensorType(shape=shape, dtype=np.float16) if image_type == False else ct.ImageType(shape=shape, color_layout=ct.colorlayout.RGB))
+    G_trace = torch.jit.trace(G, G_inputs if len(G_inputs) == 1 else [G_inputs])
+    G_model = ct.convert(
+        G_trace,
+        inputs=convert_inputs if len(convert_inputs) <= 1 else [convert_inputs],
+        minimum_deployment_target=ct.target.macOS15,
+        compute_precision=ct.precision.FLOAT16,
+        compute_units=ct.ComputeUnit.CPU_AND_NE
+    )
+    G_model.save("out/" + name + ".mlpackage")
+
+if __name__ == "__main__":
+    model_192 = create_scaled_model(CONFIG_DICT["large_192"])
+    model_288 = create_scaled_model(CONFIG_DICT["large_288"])
+    model_384 = create_scaled_model(CONFIG_DICT["large_384"])
+    load_and_show_examples((model_192, model_288, model_384))
+
+    # save_coreml_packages(model_192)
diff --git a/src/depth_pro/depth_pro.py b/src/depth_pro/depth_pro.py
index f31b4e1..ae6c0c4 100644
--- a/src/depth_pro/depth_pro.py
+++ b/src/depth_pro/depth_pro.py
@@ -35,6 +35,9 @@ class DepthProConfig:
     fov_encoder_preset: Optional[ViTPreset] = None
     use_fov_head: bool = True
+    encoder_scale_size: tuple[int] = ()
+    head_paddings: tuple[int] = ()
+    fov_head_paddings: tuple[int] = ()
 
 
 DEFAULT_MONODEPTH_CONFIG_DICT = DepthProConfig(
     patch_encoder_preset="dinov2l16_384",
diff --git a/src/depth_pro/network/decoder.py b/src/depth_pro/network/decoder.py
index 770665f..b1256e0 100644
--- a/src/depth_pro/network/decoder.py
+++ b/src/depth_pro/network/decoder.py
@@ -169,6 +169,10 @@ def forward(self, x0: torch.Tensor, x1: torch.Tensor | None = None) -> torch.Tensor:
         if x1 is not None:
             res = self.resnet1(x1)
+            _, _, Hx, Wx = x.shape
+            _, _, Hres, Wres = res.shape
+            if Hx != Hres or Wx != Wres:
+                x = nn.functional.interpolate(x, size=(Hres, Wres), mode="bilinear", align_corners=False)
             x = self.skip_add.add(x, res)
 
         x = self.resnet2(x)
 
diff --git a/src/depth_pro/network/encoder.py b/src/depth_pro/network/encoder.py
index a3a3da1..09f638b 100644
--- a/src/depth_pro/network/encoder.py
+++ b/src/depth_pro/network/encoder.py
@@ -169,7 +169,7 @@ def _create_pyramid(
 
     def split(self, x: torch.Tensor, overlap_ratio: float = 0.25) -> torch.Tensor:
         """Split the input into small patches with sliding window."""
-        patch_size = 384
+        patch_size = self.patch_encoder.patch_embed.img_size[0]
         patch_stride = int(patch_size * (1 - overlap_ratio))
 
         image_size = x.shape[-1]
diff --git a/src/depth_pro/network/vit.py b/src/depth_pro/network/vit.py
index c6c3768..73036b9 100644
--- a/src/depth_pro/network/vit.py
+++ b/src/depth_pro/network/vit.py
@@ -48,11 +48,11 @@ def forward_features_eva_fixed(self, x):
     return x
 
 
-def resize_vit(model: nn.Module, img_size) -> nn.Module:
+def resize_vit(model: nn.Module, img_size, grid_size) -> nn.Module:
     """Resample the ViT module to the given size."""
     patch_size = model.patch_embed.patch_size
     model.patch_embed.img_size = img_size
-    grid_size = tuple([s // p for s, p in zip(img_size, patch_size)])
+    # grid_size = tuple([s // p for s, p in zip(img_size, patch_size)])
     model.patch_embed.grid_size = grid_size
 
     pos_embed = resample_abs_pos_embed(
diff --git a/src/depth_pro/network/vit_factory.py b/src/depth_pro/network/vit_factory.py
index 2cd899f..68db62b 100644
--- a/src/depth_pro/network/vit_factory.py
+++ b/src/depth_pro/network/vit_factory.py
@@ -24,6 +24,8 @@
 
 
 ViTPreset = Literal[
+    "dinov2l16_192",
+    "dinov2l16_288",
     "dinov2l16_384",
 ]
 
@@ -37,6 +39,7 @@ class ViTConfig:
 
     img_size: int = 384
     patch_size: int = 16
+    grid_size: int = 24
 
     # In case we need to rescale the backbone when loading from timm.
     timm_preset: Optional[str] = None
@@ -51,6 +54,30 @@ class ViTConfig:
 
 
 VIT_CONFIG_DICT: Dict[ViTPreset, ViTConfig] = {
+    "dinov2l16_192": ViTConfig(
+        in_chans=3,
+        embed_dim=1024,
+        encoder_feature_layer_ids=[5, 11, 17, 23],
+        encoder_feature_dims=[256, 512, 1024, 1024],
+        img_size=192,
+        patch_size=16,
+        grid_size=24,
+        timm_preset="vit_large_patch14_dinov2",
+        timm_img_size=518,
+        timm_patch_size=14,
+    ),
+    "dinov2l16_288": ViTConfig(
+        in_chans=3,
+        embed_dim=1024,
+        encoder_feature_layer_ids=[5, 11, 17, 23],
+        encoder_feature_dims=[256, 512, 1024, 1024],
+        img_size=288,
+        patch_size=16,
+        grid_size=24,
+        timm_preset="vit_large_patch14_dinov2",
+        timm_img_size=518,
+        timm_patch_size=14,
+    ),
     "dinov2l16_384": ViTConfig(
         in_chans=3,
         embed_dim=1024,
@@ -58,6 +85,7 @@ class ViTConfig:
         encoder_feature_dims=[256, 512, 1024, 1024],
         img_size=384,
         patch_size=16,
+        grid_size=24,
         timm_preset="vit_large_patch14_dinov2",
         timm_img_size=518,
         timm_patch_size=14,
@@ -107,7 +135,7 @@ def create_vit(
     if config.patch_size != config.timm_patch_size:
         model.model = resize_patch_embed(model.model, new_patch_size=patch_size)
     if config.img_size != config.timm_img_size:
-        model.model = resize_vit(model.model, img_size=img_size)
+        model.model = resize_vit(model.model, img_size=img_size, grid_size=(config.grid_size, config.grid_size))
 
     if checkpoint_uri is not None:
         state_dict = torch.load(checkpoint_uri, map_location="cpu")
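
For a quick sanity check of the exported packages, the sketch below loads one of the .mlpackage files written by save_mlpackage above and runs a prediction with coremltools. It is a minimal sketch, not part of the patch: it assumes the package has a single fixed-shape multi-array input (which is how the encoder is converted here), reads the input name and shape from the model spec since Core ML assigns names during tracing, and requires macOS with the Core ML runtime for predict.

import coremltools as ct
import numpy as np

# Load one of the packages produced by save_mlpackage (path taken from the script above).
mlmodel = ct.models.MLModel("out/DepthPro_encoder.mlpackage")

# Input names and shapes are generated during tracing, so read them from the spec
# instead of hard-coding them.
input_desc = mlmodel.get_spec().description.input[0]
shape = list(input_desc.type.multiArrayType.shape)

# Run a prediction on random data and print the output shapes.
x = np.random.rand(*shape).astype(np.float32)
outputs = mlmodel.predict({input_desc.name: x})
print({name: np.asarray(value).shape for name, value in outputs.items()})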