From b8dbdc50dd861593efb08f3b8f7da3ae24c68d0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Harri=20Sma=CC=8Att?= Date: Tue, 15 Oct 2024 18:24:59 +0300 Subject: [PATCH] Convert from Depth Pro default 1536x1536 implementation to 1024x1024 tensor CoreML programs --- .gitignore | 4 + README.md | 2 + convert_to_coreml.py | 175 +++++++++++++++++++++++++++++++ src/depth_pro/network/encoder.py | 12 +-- 4 files changed, 187 insertions(+), 6 deletions(-) create mode 100644 .gitignore create mode 100644 convert_to_coreml.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9177316 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ + +checkpoints +out + diff --git a/README.md b/README.md index 6c4ea61..cc6d313 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,5 @@ +# This fork adds `convert_to_coreml.py` script to convert original Depth Pro model to CoreML programs + ## Depth Pro: Sharp Monocular Metric Depth in Less Than a Second This software project accompanies the research paper: diff --git a/convert_to_coreml.py b/convert_to_coreml.py new file mode 100644 index 0000000..941699a --- /dev/null +++ b/convert_to_coreml.py @@ -0,0 +1,175 @@ +import coremltools as ct +import logging +import math +import numpy as np +import torch +import torch.nn as nn +import torch.optim as optim + +from matplotlib import pyplot as plt +from typing import Tuple + +from src.depth_pro.depth_pro import ( + create_model_and_transforms, + create_backbone_model, + DEFAULT_MONODEPTH_CONFIG_DICT +) +from src.depth_pro.network.fov import FOVNetwork +from src.depth_pro.network.vit import resize_vit, resize_patch_embed +from src.depth_pro.utils import load_rgb + +from torchvision.transforms import ( + Compose, + ConvertImageDtype, + Lambda, + Normalize, + ToTensor +) + +class DepthProRun(nn.Module): + def __init__(self, transform: nn.Module, encoder: nn.Module, decoder: nn.Module, depth: nn.Module): + super().__init__() + self.transform = transform + self.encoder = encoder + self.decoder = decoder + self.depth = depth + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if x.shape[0] == 3: + x = x.unsqueeze(0) + image = self.transform(x) + encodings = self.encoder(image) + features, features_0 = self.decoder(encodings) + depth = self.depth([image, features, features_0]) + return depth + +class Depth(nn.Module): + def __init__(self, head: nn.Module, fov: nn.Module): + super(Depth, self).__init__() + self.head = head + self.fov = fov + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + x = inputs[0] + features = inputs[1] + features_0 = inputs[2] + _, _, H, W = x.shape + # using default size 1536 until fov_encoder resizing succeeds + # 1024 is the expected size to compare against then + if H != 1536 or W != 1536: + x = nn.functional.interpolate( + x, + size=(1536, 1536), + mode="bilinear", + align_corners=False, + ) + # this is needed until resizing fov_encoder succeeds + # the surrent resized size (32, 32) is correct here then + features_0 = nn.functional.interpolate( + features_0, + size=(48, 48), + mode="bilinear", + align_corners=False, + ) + canonical_inverse_depth = self.head(features) + fov_deg = self.fov.forward(x, features_0.detach()) + f_px = 0.5 * torch.tan(math.pi * fov_deg.to(torch.float) / 360.0) + inverse_depth = canonical_inverse_depth * f_px + depth = 1.0 / inverse_depth.clamp(min=1e-4, max=1e4) + return depth + +class Interpolate(nn.Module): + def __init__(self, size, mode): + super(Interpolate, self).__init__() + self.interp = nn.functional.interpolate + self.size = size + self.mode = mode + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.interp(x, size=self.size, mode=self.mode, align_corners=False) + return x + +def save_mlpackage(G, shapes, name): + G.eval() + G_inputs = [] + convert_inputs = [] + for shape in shapes: + G_inputs.append(torch.randn(shape)) + convert_inputs.append(ct.TensorType(shape=shape, dtype=np.float16)) + G_trace = torch.jit.trace(G, G_inputs if len(G_inputs) == 1 else [G_inputs]) + G_model = ct.convert( + G_trace, + inputs=convert_inputs if len(convert_inputs) <= 1 else [convert_inputs], + minimum_deployment_target=ct.target.macOS15, + compute_precision=ct.precision.FLOAT16, + compute_units=ct.ComputeUnit.CPU_AND_NE + ) + G_model.save("out/" + name + ".mlpackage") + +def create_scaled_model() -> Tuple[nn.Module, nn.Module, nn.Module]: + # from run.py + model, _ = create_model_and_transforms( + device=torch.device("cpu"), + precision=torch.float32, + ) + + new_img_size = (256, 256) + # resize to 256x4 = 1024x1024 input image + model.encoder.patch_encoder = resize_patch_embed(model.encoder.patch_encoder) + model.encoder.patch_encoder = resize_vit(model.encoder.patch_encoder, img_size=new_img_size) + model.encoder.image_encoder = resize_patch_embed(model.encoder.image_encoder) + model.encoder.image_encoder = resize_vit(model.encoder.image_encoder, img_size=new_img_size) + model.encoder.out_size = int( + model.encoder.patch_encoder.patch_embed.img_size[0] // model.encoder.patch_encoder.patch_embed.patch_size[0] + ) + + # this is still under works to resize fov_encoder to 1024x1024 size too + # fov_encoder, _ = create_backbone_model(preset = DEFAULT_MONODEPTH_CONFIG_DICT.fov_encoder_preset) + # fov_encoder = resize_patch_embed(fov_encoder) + # fov_encoder = resize_vit(fov_encoder, img_size=new_img_size) + # model.fov = FOVNetwork(num_features=model.decoder.dim_decoder, fov_encoder=fov_encoder) + + # from depth_pro.py + transform = nn.Sequential( + #[ + #ToTensor(), + #Lambda(lambda x: x.to(device)), + Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), + Interpolate( + size=(model.img_size, model.img_size), + mode="bilinear" + ), + ConvertImageDtype(torch.float32), + #] + ) + + depth = Depth(model.head, model.fov) + return transform, model, depth + +def load_and_show_example(transform: nn.Module, model: nn.Module, depth: nn.Module): + image, _, _ = load_rgb("data/example.jpg") + depth_pro_run = DepthProRun(transform, model.encoder, model.decoder, depth) + + depth_pro = Compose([ToTensor(), Lambda(lambda x: x.to(torch.device("cpu"))), depth_pro_run]) + depth_map = depth_pro(image).detach().cpu().numpy().squeeze() + + plt.ion() + fig = plt.figure() + ax_rgb = fig.add_subplot(121) + ax_disp = fig.add_subplot(122) + ax_rgb.imshow(image) + ax_disp.imshow(depth_map, cmap="turbo") + fig.canvas.draw() + fig.canvas.flush_events() + plt.show(block=True) + +def save_coreml_packages(transform: nn.Module, model: nn.Module, depth: nn.Module): + save_mlpackage(transform, [[1, 3, 1024, 1024]], "DepthPro_transform") + save_mlpackage(model.encoder, [[1, 3, 1024, 1024]], "DepthPro_encoder") + save_mlpackage(model.decoder, [[1, 256, 512, 512], [1, 256, 256, 256], [1, 512, 128, 128], [1, 1024, 64, 64], [1, 1024, 32, 32]], "DepthPro_decoder") + save_mlpackage(depth, [[1, 3, 1024, 1024], [1, 256, 512, 512], [1, 256, 32, 32]], "DepthPro_depth") + +if __name__ == "__main__": + transform, model, depth = create_scaled_model() + load_and_show_example(transform, model, depth) + save_coreml_packages(transform, model, depth) diff --git a/src/depth_pro/network/encoder.py b/src/depth_pro/network/encoder.py index a3a3da1..eea1bbd 100644 --- a/src/depth_pro/network/encoder.py +++ b/src/depth_pro/network/encoder.py @@ -169,7 +169,7 @@ def _create_pyramid( def split(self, x: torch.Tensor, overlap_ratio: float = 0.25) -> torch.Tensor: """Split the input into small patches with sliding window.""" - patch_size = 384 + patch_size = 256 patch_stride = int(patch_size * (1 - overlap_ratio)) image_size = x.shape[-1] @@ -276,7 +276,7 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]: self.out_size, ) x_latent0_features = self.merge( - x_latent0_encodings[: batch_size * 5 * 5], batch_size=batch_size, padding=3 + x_latent0_encodings[: batch_size * 5 * 5], batch_size=batch_size, padding=2 ) x_latent1_encodings = self.reshape_feature( @@ -285,21 +285,21 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]: self.out_size, ) x_latent1_features = self.merge( - x_latent1_encodings[: batch_size * 5 * 5], batch_size=batch_size, padding=3 + x_latent1_encodings[: batch_size * 5 * 5], batch_size=batch_size, padding=2 ) # Split the 35 batch size from pyramid encoding back into 5x5+3x3+1x1. x0_encodings, x1_encodings, x2_encodings = torch.split( x_pyramid_encodings, - [len(x0_patches), len(x1_patches), len(x2_patches)], + [x0_patches.shape[0], x1_patches.shape[0], x2_patches.shape[0]], dim=0, ) # 96x96 feature maps by merging 5x5 @ 24x24 patches with overlaps. - x0_features = self.merge(x0_encodings, batch_size=batch_size, padding=3) + x0_features = self.merge(x0_encodings, batch_size=batch_size, padding=2) # 48x84 feature maps by merging 3x3 @ 24x24 patches with overlaps. - x1_features = self.merge(x1_encodings, batch_size=batch_size, padding=6) + x1_features = self.merge(x1_encodings, batch_size=batch_size, padding=4) # 24x24 feature maps. x2_features = x2_encodings