Commit
Convert from Depth Pro default 1536x1536 size to 1024x1024 tensor CoreML programs
harism committed Oct 16, 2024
1 parent b2cd0d5 commit 9ba2417
Showing 4 changed files with 131 additions and 6 deletions.
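For orientation, the sizes in the commit title spelled out as a short sketch. It follows the `256x4 = 1024x1024 input image` note in `convert_to_coreml.py` below; the factor-of-32 relation for the lowest-resolution feature map is inferred from the 48x48 and 32x32 shapes used in the new script.

```python
# The commit in numbers: the model input is four times the ViT window size,
# so shrinking the window from 384 to 256 shrinks the input from 1536 to 1024;
# the lowest-resolution feature map (input / 32) shrinks from 48x48 to 32x32.
assert 4 * 384 == 1536 and 4 * 256 == 1024
assert 1536 // 32 == 48 and 1024 // 32 == 32
```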
4 changes: 4 additions & 0 deletions .gitignore
@@ -0,0 +1,4 @@

checkpoints
out

2 changes: 2 additions & 0 deletions README.md
@@ -1,3 +1,5 @@
# This fork adds a `convert_to_coreml.py` script that converts the original Depth Pro model to CoreML programs

## Depth Pro: Sharp Monocular Metric Depth in Less Than a Second

This software project accompanies the research paper:
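A rough usage sketch for the script added below, intended as a sanity check rather than a recipe: it assumes the Depth Pro checkpoint already sits under `checkpoints/` (now git-ignored above), that the `out/` directory exists before the script writes into it, and that prediction runs on macOS 15 or later, which is the script's deployment target. The input names are whatever `torch.jit.trace` assigned, so they are inspected first.

```python
# Hypothetical smoke test after running `python convert_to_coreml.py`.
import coremltools as ct
import numpy as np

mlmodel = ct.models.MLModel("out/DepthPro_transform.mlpackage")
spec = mlmodel.get_spec()
print([inp.name for inp in spec.description.input])   # check the traced input name

name = spec.description.input[0].name
x = np.random.rand(1, 3, 1024, 1024).astype(np.float32)
print({k: v.shape for k, v in mlmodel.predict({name: x}).items()})
```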
119 changes: 119 additions & 0 deletions convert_to_coreml.py
@@ -0,0 +1,119 @@
import coremltools as ct
import logging
import math
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

from src.depth_pro.depth_pro import (
create_model_and_transforms,
create_backbone_model,
DEFAULT_MONODEPTH_CONFIG_DICT
)
from src.depth_pro.network.fov import FOVNetwork
from src.depth_pro.network.vit import resize_vit, resize_patch_embed

from torchvision.transforms import (
ConvertImageDtype,
Normalize,
)

class Depth(nn.Module):
def __init__(self, head: nn.Module, fov: nn.Module):
super(Depth, self).__init__()
self.head = head
self.fov = fov

def forward(self, inputs: torch.Tensor) -> torch.Tensor:
x = inputs[0]
features = inputs[1]
features_0 = inputs[2]
_, _, H, W = x.shape
if H != 1536 or W != 1536:
x = nn.functional.interpolate(
x,
size=(1536, 1536),
mode="bilinear",
align_corners=False,
)
features_0 = nn.functional.interpolate(
features_0,
size=(48, 48),
mode="bilinear",
align_corners=False,
)
canonical_inverse_depth = self.head(features)
fov_deg = self.fov.forward(x, features_0.detach())
f_px = 0.5 * torch.tan(math.pi * fov_deg.to(torch.float) / 360.0)
inverse_depth = canonical_inverse_depth * f_px
depth = 1.0 / inverse_depth.clamp(min=1e-4, max=1e4)
return depth

class Interpolate(nn.Module):
def __init__(self, size, mode):
super(Interpolate, self).__init__()
self.interp = nn.functional.interpolate
self.size = size
self.mode = mode

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.interp(x, size=self.size, mode=self.mode, align_corners=False)
return x

def save_mlpackage(G, shapes, name):
G.eval()
G_inputs = []
convert_inputs = []
for shape in shapes:
G_inputs.append(torch.randn(shape))
convert_inputs.append(ct.TensorType(shape=shape, dtype=np.float16))
G_trace = torch.jit.trace(G, G_inputs if len(G_inputs) == 1 else [G_inputs])
G_model = ct.convert(
G_trace,
inputs=convert_inputs if len(convert_inputs) <= 1 else [convert_inputs],
minimum_deployment_target=ct.target.macOS15,
compute_precision=ct.precision.FLOAT16,
compute_units=ct.ComputeUnit.CPU_AND_NE
)
G_model.save("out/" + name + ".mlpackage")

def main():
# from run.py
model, _ = create_model_and_transforms(
device=torch.device("cpu"),
precision=torch.float32,
)

new_img_size = (256, 256)
# resize to 256x4 = 1024x1024 input image
model.encoder.patch_encoder = resize_patch_embed(model.encoder.patch_encoder)
model.encoder.patch_encoder = resize_vit(model.encoder.patch_encoder, img_size=new_img_size)
model.encoder.image_encoder = resize_patch_embed(model.encoder.image_encoder)
model.encoder.image_encoder = resize_vit(model.encoder.image_encoder, img_size=new_img_size)
model.encoder.out_size = int(
model.encoder.patch_encoder.patch_embed.img_size[0] // model.encoder.patch_encoder.patch_embed.patch_size[0]
)

# from depth_pro.py
transform = nn.Sequential(
#[
#ToTensor(),
#Lambda(lambda x: x.to(device)),
Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
Interpolate(
size=(model.img_size, model.img_size),
mode="bilinear"
),
ConvertImageDtype(torch.float32),
#]
)

depth = Depth(model.head, model.fov)
save_mlpackage(transform, [[1, 3, 1024, 1024]], "DepthPro_transform")
save_mlpackage(model.encoder, [[1, 3, 1024, 1024]], "DepthPro_encoder")
save_mlpackage(model.decoder, [[1, 256, 512, 512], [1, 256, 256, 256], [1, 512, 128, 128], [1, 1024, 64, 64], [1, 1024, 32, 32]], "DepthPro_decoder")
save_mlpackage(depth, [[1, 3, 1024, 1024], [1, 256, 512, 512], [1, 256, 32, 32]], "DepthPro_depth")

if __name__ == "__main__":
main()
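A hypothetical shape check that could be appended to `main()` before the `save_mlpackage` calls, to confirm the resized PyTorch model actually runs at 1024x1024. It assumes `model.encoder` returns the encoding list that `model.decoder` expects, and that `model.decoder` returns the `(features, features_0)` pair consumed by `Depth`, mirroring the upstream `depth_pro.py`.

```python
# Hypothetical end-to-end shape check of the resized PyTorch model before
# conversion (assumes the encoder/decoder interfaces described above).
with torch.no_grad():
    x = transform(torch.rand(1, 3, 1024, 1024))
    encodings = model.encoder(x)
    features, features_0 = model.decoder(encodings)
    out = depth([x, features, features_0])
    print(out.shape)   # expected to be 1024x1024 at the new input size
```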
12 changes: 6 additions & 6 deletions src/depth_pro/network/encoder.py
@@ -169,7 +169,7 @@ def _create_pyramid(

def split(self, x: torch.Tensor, overlap_ratio: float = 0.25) -> torch.Tensor:
"""Split the input into small patches with sliding window."""
-        patch_size = 384
+        patch_size = 256
patch_stride = int(patch_size * (1 - overlap_ratio))

image_size = x.shape[-1]
@@ -276,7 +276,7 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
self.out_size,
)
x_latent0_features = self.merge(
-            x_latent0_encodings[: batch_size * 5 * 5], batch_size=batch_size, padding=3
+            x_latent0_encodings[: batch_size * 5 * 5], batch_size=batch_size, padding=2
)

x_latent1_encodings = self.reshape_feature(
@@ -285,21 +285,21 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
self.out_size,
)
x_latent1_features = self.merge(
-            x_latent1_encodings[: batch_size * 5 * 5], batch_size=batch_size, padding=3
+            x_latent1_encodings[: batch_size * 5 * 5], batch_size=batch_size, padding=2
)

# Split the 35 batch size from pyramid encoding back into 5x5+3x3+1x1.
x0_encodings, x1_encodings, x2_encodings = torch.split(
x_pyramid_encodings,
-            [len(x0_patches), len(x1_patches), len(x2_patches)],
+            [x0_patches.shape[0], x1_patches.shape[0], x2_patches.shape[0]],
dim=0,
)

# 96x96 feature maps by merging 5x5 @ 24x24 patches with overlaps.
-        x0_features = self.merge(x0_encodings, batch_size=batch_size, padding=3)
+        x0_features = self.merge(x0_encodings, batch_size=batch_size, padding=2)

        # 48x48 feature maps by merging 3x3 @ 24x24 patches with overlaps.
-        x1_features = self.merge(x1_encodings, batch_size=batch_size, padding=6)
+        x1_features = self.merge(x1_encodings, batch_size=batch_size, padding=4)

# 24x24 feature maps.
x2_features = x2_encodings
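The padding changes above follow from the smaller per-window token grid; a quick arithmetic sketch, assuming the 16-pixel ViT patch embedding used by the default Depth Pro backbone:

```python
# The merge() padding scales with the per-window token grid: a 384-pixel window
# gives 24x24 ViT tokens (assuming 16-pixel patches), a 256-pixel window gives
# 16x16, so each padding value shrinks by the same 2/3 factor.
old_tokens, new_tokens = 384 // 16, 256 // 16        # 24 -> 16
for old_padding in (3, 6):
    new_padding = old_padding * new_tokens // old_tokens
    print(old_padding, "->", new_padding)            # 3 -> 2, 6 -> 4
```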
