-
Notifications
You must be signed in to change notification settings - Fork 273
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Convert from Depth Pro default 1536x1536 implementation to 1024x1024 …
…tensor CoreML programs
- Loading branch information
Showing
4 changed files
with
187 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
|
||
checkpoints | ||
out | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,175 @@ | ||
import coremltools as ct | ||
import logging | ||
import math | ||
import numpy as np | ||
import torch | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
|
||
from matplotlib import pyplot as plt | ||
from typing import Tuple | ||
|
||
from src.depth_pro.depth_pro import ( | ||
create_model_and_transforms, | ||
create_backbone_model, | ||
DEFAULT_MONODEPTH_CONFIG_DICT | ||
) | ||
from src.depth_pro.network.fov import FOVNetwork | ||
from src.depth_pro.network.vit import resize_vit, resize_patch_embed | ||
from src.depth_pro.utils import load_rgb | ||
|
||
from torchvision.transforms import ( | ||
Compose, | ||
ConvertImageDtype, | ||
Lambda, | ||
Normalize, | ||
ToTensor | ||
) | ||
|
||
class DepthProRun(nn.Module): | ||
def __init__(self, transform: nn.Module, encoder: nn.Module, decoder: nn.Module, depth: nn.Module): | ||
super().__init__() | ||
self.transform = transform | ||
self.encoder = encoder | ||
self.decoder = decoder | ||
self.depth = depth | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
if x.shape[0] == 3: | ||
x = x.unsqueeze(0) | ||
image = self.transform(x) | ||
encodings = self.encoder(image) | ||
features, features_0 = self.decoder(encodings) | ||
depth = self.depth([image, features, features_0]) | ||
return depth | ||
|
||
class Depth(nn.Module): | ||
def __init__(self, head: nn.Module, fov: nn.Module): | ||
super(Depth, self).__init__() | ||
self.head = head | ||
self.fov = fov | ||
|
||
def forward(self, inputs: torch.Tensor) -> torch.Tensor: | ||
x = inputs[0] | ||
features = inputs[1] | ||
features_0 = inputs[2] | ||
_, _, H, W = x.shape | ||
# using default size 1536 until fov_encoder resizing succeeds | ||
# 1024 is the expected size to compare against then | ||
if H != 1536 or W != 1536: | ||
x = nn.functional.interpolate( | ||
x, | ||
size=(1536, 1536), | ||
mode="bilinear", | ||
align_corners=False, | ||
) | ||
# this is needed until resizing fov_encoder succeeds | ||
# the surrent resized size (32, 32) is correct here then | ||
features_0 = nn.functional.interpolate( | ||
features_0, | ||
size=(48, 48), | ||
mode="bilinear", | ||
align_corners=False, | ||
) | ||
canonical_inverse_depth = self.head(features) | ||
fov_deg = self.fov.forward(x, features_0.detach()) | ||
f_px = 0.5 * torch.tan(math.pi * fov_deg.to(torch.float) / 360.0) | ||
inverse_depth = canonical_inverse_depth * f_px | ||
depth = 1.0 / inverse_depth.clamp(min=1e-4, max=1e4) | ||
return depth | ||
|
||
class Interpolate(nn.Module): | ||
def __init__(self, size, mode): | ||
super(Interpolate, self).__init__() | ||
self.interp = nn.functional.interpolate | ||
self.size = size | ||
self.mode = mode | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
x = self.interp(x, size=self.size, mode=self.mode, align_corners=False) | ||
return x | ||
|
||
def save_mlpackage(G, shapes, name): | ||
G.eval() | ||
G_inputs = [] | ||
convert_inputs = [] | ||
for shape in shapes: | ||
G_inputs.append(torch.randn(shape)) | ||
convert_inputs.append(ct.TensorType(shape=shape, dtype=np.float16)) | ||
G_trace = torch.jit.trace(G, G_inputs if len(G_inputs) == 1 else [G_inputs]) | ||
G_model = ct.convert( | ||
G_trace, | ||
inputs=convert_inputs if len(convert_inputs) <= 1 else [convert_inputs], | ||
minimum_deployment_target=ct.target.macOS15, | ||
compute_precision=ct.precision.FLOAT16, | ||
compute_units=ct.ComputeUnit.CPU_AND_NE | ||
) | ||
G_model.save("out/" + name + ".mlpackage") | ||
|
||
def create_scaled_model() -> Tuple[nn.Module, nn.Module, nn.Module]: | ||
# from run.py | ||
model, _ = create_model_and_transforms( | ||
device=torch.device("cpu"), | ||
precision=torch.float32, | ||
) | ||
|
||
new_img_size = (256, 256) | ||
# resize to 256x4 = 1024x1024 input image | ||
model.encoder.patch_encoder = resize_patch_embed(model.encoder.patch_encoder) | ||
model.encoder.patch_encoder = resize_vit(model.encoder.patch_encoder, img_size=new_img_size) | ||
model.encoder.image_encoder = resize_patch_embed(model.encoder.image_encoder) | ||
model.encoder.image_encoder = resize_vit(model.encoder.image_encoder, img_size=new_img_size) | ||
model.encoder.out_size = int( | ||
model.encoder.patch_encoder.patch_embed.img_size[0] // model.encoder.patch_encoder.patch_embed.patch_size[0] | ||
) | ||
|
||
# this is still under works to resize fov_encoder to 1024x1024 size too | ||
# fov_encoder, _ = create_backbone_model(preset = DEFAULT_MONODEPTH_CONFIG_DICT.fov_encoder_preset) | ||
# fov_encoder = resize_patch_embed(fov_encoder) | ||
# fov_encoder = resize_vit(fov_encoder, img_size=new_img_size) | ||
# model.fov = FOVNetwork(num_features=model.decoder.dim_decoder, fov_encoder=fov_encoder) | ||
|
||
# from depth_pro.py | ||
transform = nn.Sequential( | ||
#[ | ||
#ToTensor(), | ||
#Lambda(lambda x: x.to(device)), | ||
Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), | ||
Interpolate( | ||
size=(model.img_size, model.img_size), | ||
mode="bilinear" | ||
), | ||
ConvertImageDtype(torch.float32), | ||
#] | ||
) | ||
|
||
depth = Depth(model.head, model.fov) | ||
return transform, model, depth | ||
|
||
def load_and_show_example(transform: nn.Module, model: nn.Module, depth: nn.Module): | ||
image, _, _ = load_rgb("data/example.jpg") | ||
depth_pro_run = DepthProRun(transform, model.encoder, model.decoder, depth) | ||
|
||
depth_pro = Compose([ToTensor(), Lambda(lambda x: x.to(torch.device("cpu"))), depth_pro_run]) | ||
depth_map = depth_pro(image).detach().cpu().numpy().squeeze() | ||
|
||
plt.ion() | ||
fig = plt.figure() | ||
ax_rgb = fig.add_subplot(121) | ||
ax_disp = fig.add_subplot(122) | ||
ax_rgb.imshow(image) | ||
ax_disp.imshow(depth_map, cmap="turbo") | ||
fig.canvas.draw() | ||
fig.canvas.flush_events() | ||
plt.show(block=True) | ||
|
||
def save_coreml_packages(transform: nn.Module, model: nn.Module, depth: nn.Module): | ||
save_mlpackage(transform, [[1, 3, 1024, 1024]], "DepthPro_transform") | ||
save_mlpackage(model.encoder, [[1, 3, 1024, 1024]], "DepthPro_encoder") | ||
save_mlpackage(model.decoder, [[1, 256, 512, 512], [1, 256, 256, 256], [1, 512, 128, 128], [1, 1024, 64, 64], [1, 1024, 32, 32]], "DepthPro_decoder") | ||
save_mlpackage(depth, [[1, 3, 1024, 1024], [1, 256, 512, 512], [1, 256, 32, 32]], "DepthPro_depth") | ||
|
||
if __name__ == "__main__": | ||
transform, model, depth = create_scaled_model() | ||
load_and_show_example(transform, model, depth) | ||
save_coreml_packages(transform, model, depth) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters