
Commit

Convert from Depth Pro default 1536x1536 implementation to 1024x1024 tensor CoreML programs
harism committed Oct 16, 2024
1 parent b2cd0d5 commit 6a19ba7
Showing 4 changed files with 178 additions and 6 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -0,0 +1,4 @@

checkpoints
out

2 changes: 2 additions & 0 deletions README.md
@@ -1,3 +1,5 @@
# This fork adds a `convert_to_coreml.py` script that converts the original Depth Pro model to CoreML programs

## Depth Pro: Sharp Monocular Metric Depth in Less Than a Second

This software project accompanies the research paper:
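The conversion entry points can also be driven directly from Python. A minimal sketch (not part of the commit), assuming the Depth Pro checkpoint has been downloaded to `checkpoints/` and an `out/` directory exists; `create_scaled_model` and `save_coreml_packages` are defined in `convert_to_coreml.py` below:

```python
# Hypothetical driver mirroring the script's __main__ block, minus the preview window.
from convert_to_coreml import create_scaled_model, save_coreml_packages

transform, model, depth = create_scaled_model()   # 1024x1024 variant of Depth Pro
save_coreml_packages(transform, model, depth)     # writes out/DepthPro_*.mlpackage
```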
166 changes: 166 additions & 0 deletions convert_to_coreml.py
@@ -0,0 +1,166 @@
import coremltools as ct
import logging
import math
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from matplotlib import pyplot as plt
from typing import Tuple

from src.depth_pro.depth_pro import (
    create_model_and_transforms,
    create_backbone_model,
    DEFAULT_MONODEPTH_CONFIG_DICT
)
from src.depth_pro.network.fov import FOVNetwork
from src.depth_pro.network.vit import resize_vit, resize_patch_embed
from src.depth_pro.utils import load_rgb

from torchvision.transforms import (
    Compose,
    ConvertImageDtype,
    Lambda,
    Normalize,
    ToTensor
)

class DepthProRun(nn.Module):
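    """Chain transform -> encoder -> decoder -> depth/FOV head so the whole
    Depth Pro pipeline can be run end to end on an example image before the
    individual stages are exported to CoreML.
    """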
    def __init__(self, transform: nn.Module, encoder: nn.Module, decoder: nn.Module, depth: nn.Module):
        super().__init__()
        self.transform = transform
        self.encoder = encoder
        self.decoder = decoder
        self.depth = depth

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.shape[0] == 3:
            x = x.unsqueeze(0)
        image = self.transform(x)
        encodings = self.encoder(image)
        features, features_0 = self.decoder(encodings)
        depth = self.depth([image, features, features_0])
        return depth

class Depth(nn.Module):
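    """Monodepth head plus FOV network.

    The FOV path still expects a 1536x1536 image and 48x48 low-resolution
    features, so both are upsampled when the pipeline runs at 1024x1024.
    """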
    def __init__(self, head: nn.Module, fov: nn.Module):
        super(Depth, self).__init__()
        self.head = head
        self.fov = fov

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        x = inputs[0]
        features = inputs[1]
        features_0 = inputs[2]
        _, _, H, W = x.shape
        if H != 1536 or W != 1536:
            x = nn.functional.interpolate(
                x,
                size=(1536, 1536),
                mode="bilinear",
                align_corners=False,
            )
            features_0 = nn.functional.interpolate(
                features_0,
                size=(48, 48),
                mode="bilinear",
                align_corners=False,
            )
        canonical_inverse_depth = self.head(features)
        fov_deg = self.fov.forward(x, features_0.detach())
        f_px = 0.5 * torch.tan(math.pi * fov_deg.to(torch.float) / 360.0)
        inverse_depth = canonical_inverse_depth * f_px
        depth = 1.0 / inverse_depth.clamp(min=1e-4, max=1e4)
        return depth

class Interpolate(nn.Module):
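    """Wrap nn.functional.interpolate as a module so the resize step can sit
    inside nn.Sequential and be traced for conversion.
    """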
    def __init__(self, size, mode):
        super(Interpolate, self).__init__()
        self.interp = nn.functional.interpolate
        self.size = size
        self.mode = mode

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.interp(x, size=self.size, mode=self.mode, align_corners=False)
        return x

def save_mlpackage(G, shapes, name):
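    """Trace the module with random inputs of the given shapes and convert the
    trace to a float16 ML Program targeting the CPU and Neural Engine.
    """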
    G.eval()
    G_inputs = []
    convert_inputs = []
    for shape in shapes:
        G_inputs.append(torch.randn(shape))
        convert_inputs.append(ct.TensorType(shape=shape, dtype=np.float16))
    G_trace = torch.jit.trace(G, G_inputs if len(G_inputs) == 1 else [G_inputs])
    G_model = ct.convert(
        G_trace,
        inputs=convert_inputs if len(convert_inputs) <= 1 else [convert_inputs],
        minimum_deployment_target=ct.target.macOS15,
        compute_precision=ct.precision.FLOAT16,
        compute_units=ct.ComputeUnit.CPU_AND_NE
    )
    G_model.save("out/" + name + ".mlpackage")

def create_scaled_model() -> Tuple[nn.Module, nn.Module, nn.Module]:
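    """Build Depth Pro on the CPU and resize its ViT patch/image encoders to
    256x256, shrinking the expected network input from 1536x1536 to 1024x1024.
    """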
    # from run.py
    model, _ = create_model_and_transforms(
        device=torch.device("cpu"),
        precision=torch.float32,
    )

    new_img_size = (256, 256)
    # resize ViT encoders to 256x256 patches; 4 x 256 = 1024x1024 input image
    model.encoder.patch_encoder = resize_patch_embed(model.encoder.patch_encoder)
    model.encoder.patch_encoder = resize_vit(model.encoder.patch_encoder, img_size=new_img_size)
    model.encoder.image_encoder = resize_patch_embed(model.encoder.image_encoder)
    model.encoder.image_encoder = resize_vit(model.encoder.image_encoder, img_size=new_img_size)
    model.encoder.out_size = int(
        model.encoder.patch_encoder.patch_embed.img_size[0] // model.encoder.patch_encoder.patch_embed.patch_size[0]
    )

    # from depth_pro.py
    transform = nn.Sequential(
        #[
        #ToTensor(),
        #Lambda(lambda x: x.to(device)),
        Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        Interpolate(
            size=(model.img_size, model.img_size),
            mode="bilinear"
        ),
        ConvertImageDtype(torch.float32),
        #]
    )

    depth = Depth(model.head, model.fov)

    return transform, model, depth

def load_and_show_example(transform: nn.Module, model: nn.Module, depth: nn.Module):
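    """Sanity check: run the resized model on data/example.jpg in PyTorch and
    show the image next to the predicted depth map.
    """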
    image, _, _ = load_rgb("data/example.jpg")
    depth_pro_run = DepthProRun(transform, model.encoder, model.decoder, depth)

    depth_pro = Compose([ToTensor(), Lambda(lambda x: x.to(torch.device("cpu"))), depth_pro_run])
    depth_map = depth_pro(image).detach().cpu().numpy().squeeze()

    plt.ion()
    fig = plt.figure()
    ax_rgb = fig.add_subplot(121)
    ax_disp = fig.add_subplot(122)
    ax_rgb.imshow(image)
    ax_disp.imshow(depth_map, cmap="turbo")
    fig.canvas.draw()
    fig.canvas.flush_events()
    plt.show(block=True)

def save_coreml_packages(transform: nn.Module, model: nn.Module, depth: nn.Module):
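    """Export the four pipeline stages as separate CoreML packages with fixed
    input shapes for the 1024x1024 configuration.
    """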
    save_mlpackage(transform, [[1, 3, 1024, 1024]], "DepthPro_transform")
    save_mlpackage(model.encoder, [[1, 3, 1024, 1024]], "DepthPro_encoder")
    save_mlpackage(model.decoder, [[1, 256, 512, 512], [1, 256, 256, 256], [1, 512, 128, 128], [1, 1024, 64, 64], [1, 1024, 32, 32]], "DepthPro_decoder")
    save_mlpackage(depth, [[1, 3, 1024, 1024], [1, 256, 512, 512], [1, 256, 32, 32]], "DepthPro_depth")

if __name__ == "__main__":
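    # Preview the resized model, then export the CoreML packages.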
    transform, model, depth = create_scaled_model()
    load_and_show_example(transform, model, depth)
    save_coreml_packages(transform, model, depth)
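A rough sketch, not part of the commit, of how one of the exported packages could be sanity-checked with coremltools on macOS; the input name is assigned by the converter, so it is read from the model spec rather than hard-coded:

```python
import coremltools as ct
import numpy as np

# Hypothetical check: load the exported transform package and run it once.
mlmodel = ct.models.MLModel("out/DepthPro_transform.mlpackage")
input_name = mlmodel.get_spec().description.input[0].name
dummy = np.random.rand(1, 3, 1024, 1024).astype(np.float16)
print(mlmodel.predict({input_name: dummy}).keys())
```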
12 changes: 6 additions & 6 deletions src/depth_pro/network/encoder.py
@@ -169,7 +169,7 @@ def _create_pyramid(

    def split(self, x: torch.Tensor, overlap_ratio: float = 0.25) -> torch.Tensor:
        """Split the input into small patches with sliding window."""
-       patch_size = 384
+       patch_size = 256
        patch_stride = int(patch_size * (1 - overlap_ratio))

        image_size = x.shape[-1]
@@ -276,7 +276,7 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
            self.out_size,
        )
        x_latent0_features = self.merge(
-           x_latent0_encodings[: batch_size * 5 * 5], batch_size=batch_size, padding=3
+           x_latent0_encodings[: batch_size * 5 * 5], batch_size=batch_size, padding=2
        )

        x_latent1_encodings = self.reshape_feature(
@@ -285,21 +285,21 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
            self.out_size,
        )
        x_latent1_features = self.merge(
-           x_latent1_encodings[: batch_size * 5 * 5], batch_size=batch_size, padding=3
+           x_latent1_encodings[: batch_size * 5 * 5], batch_size=batch_size, padding=2
        )

        # Split the 35 batch size from pyramid encoding back into 5x5+3x3+1x1.
        x0_encodings, x1_encodings, x2_encodings = torch.split(
            x_pyramid_encodings,
-           [len(x0_patches), len(x1_patches), len(x2_patches)],
+           [x0_patches.shape[0], x1_patches.shape[0], x2_patches.shape[0]],
            dim=0,
        )

        # 96x96 feature maps by merging 5x5 @ 24x24 patches with overlaps.
-       x0_features = self.merge(x0_encodings, batch_size=batch_size, padding=3)
+       x0_features = self.merge(x0_encodings, batch_size=batch_size, padding=2)

        # 48x48 feature maps by merging 3x3 @ 24x24 patches with overlaps.
-       x1_features = self.merge(x1_encodings, batch_size=batch_size, padding=6)
+       x1_features = self.merge(x1_encodings, batch_size=batch_size, padding=4)

        # 24x24 feature maps.
        x2_features = x2_encodings
