-
Notifications
You must be signed in to change notification settings - Fork 1
/
run_inference.py
151 lines (136 loc) · 6.66 KB
/
run_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import argparse
import torch
import os
import cv2
import numpy as np
from shap_e.models.testing_utils import test_model
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
from shap_e.models.download import load_model, load_config
from shap_e.util.notebooks import create_pan_cameras, decode_latent_images, decode_latent_mesh
from shap_e.util.data_util import load_or_create_multimodal_batch
from visualizations.blender_rendering import good_looking_render
# Command-line interface for guided Shap-E inference (consumed by `infer`).
parser = argparse.ArgumentParser()
parser.add_argument('-m', '--model_path', type=str,
                    help='path to model', required=True)
parser.add_argument('-d', '--data_path', type=str,
                    help='path to a saved guidance latent; with --encode_guidance, '
                         'directory under which the guidance cache is stored',
                    required=True)
parser.add_argument('-o', '--output_dir', type=str,
                    help='path to output dir', required=True)
parser.add_argument('-p', '--prompt', type=str,
                    help='text prompt', required=True)
parser.add_argument('--encode_guidance', action='store_true',
                    help='whether to encode the input data with shap-e encoder')
parser.add_argument('--guidance_scale', type=float, default=7.5,
                    help='guidance scale')
# Restrict to the decoders `infer` supports so bad values fail at parse time.
parser.add_argument('--render_mode', type=str, default='stf', choices=['stf', 'nerf'],
                    help='the decoding mode to render')
parser.add_argument('--output_resolution', type=int, default=64,
                    help='resolution of output images')
parser.add_argument('--render_guidance', action='store_true', default=False,
                    help='whether to render the guidance shape')
parser.add_argument('--input_guidance_object_path', type=str,
                    help='path to input guidance object; used with --encode_guidance')
parser.add_argument('--mv_image_size', type=int,
                    help='size of the multi-view images rendered when encoding the guidance shape',
                    default=256)
parser.add_argument('--verbose_blender', action='store_true', default=False,
                    help='if enabled, prints outputs from blender script')
parser.add_argument('--render_blender', action='store_true', default=False,
                    help='if enabled, renders the output (and rendered guidance) meshes in Blender')
def prompt2filename(prompt: str):
    """Convert a text prompt into a string usable as a directory name.

    Every space becomes an underscore, and the characters ? ! , " \\ and /
    are removed. All other characters pass through unchanged.
    """
    # One translation table applied in a single pass: space -> "_", the
    # listed punctuation -> deleted (mapped to None).
    substitutions = {" ": "_"}
    substitutions.update(dict.fromkeys('?!,"\\/'))
    return prompt.translate(str.maketrans(substitutions))
def _load_guided_model(model_path, device):
    """Build the text300M model with control layers and load fine-tuned weights."""
    model = load_model('text300M', device=device)
    model.wrapped.backbone.make_ctrl_layers()
    model.wrapped.set_up_controlnet_cond()
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device))
    return model


def _get_guidance_latent(args, xm, device):
    """Return the guidance latent: encode a 3D object, or load a saved tensor."""
    if not args.encode_guidance:
        # map_location keeps a GPU-saved latent loadable when running on CPU.
        return torch.load(args.data_path, map_location=device)
    print("Creating data for encoding from input 3D guidance shape")
    batch = load_or_create_multimodal_batch(
        device,
        model_path=args.input_guidance_object_path,
        mv_light_mode="basic",
        mv_image_size=args.mv_image_size,
        cache_dir=os.path.join(args.data_path, "cached_guidance"),
        verbose=args.verbose_blender,  # this will show Blender output during renders
    )
    print("Encoding")
    # Inference only: no autograd graph is needed for the encoder pass.
    with torch.no_grad():
        return xm.encoder.encode_to_bottleneck(batch)


def _render_guidance(xm, guidance_shape, cameras, output_dir, render_mode,
                     output_resolution, render_in_blender):
    """Write the guidance latent, its frames/video and meshes to <output_dir>/condition."""
    images = decode_latent_images(xm, guidance_shape, cameras, rendering_mode=render_mode)
    cond_path = os.path.join(output_dir, "condition")
    os.makedirs(cond_path, exist_ok=True)
    torch.save(guidance_shape, os.path.join(cond_path, "condition.pt"))
    videowriter = cv2.VideoWriter(os.path.join(cond_path, 'condition.mp4'),
                                  cv2.VideoWriter_fourcc(*'mp4v'), 10,
                                  (output_resolution, output_resolution))
    try:
        for i, image in enumerate(images):
            image.save(os.path.join(cond_path, f'{i:05}.png'))
            # Reverse the channel axis (RGB -> BGR order expected by OpenCV);
            # make it contiguous because the reversed slice is a strided view.
            videowriter.write(np.ascontiguousarray(np.array(image)[:, :, ::-1]))
    finally:
        # Always finalize the video file, even if a frame fails to save.
        videowriter.release()
    t = decode_latent_mesh(xm, guidance_shape).tri_mesh()
    with open(os.path.join(cond_path, "condition.obj"), 'w') as f:
        t.write_obj(f)
    with open(os.path.join(cond_path, "condition.ply"), 'wb') as f:
        t.write_ply(f)
    if render_in_blender:
        mesh_path = os.path.join(cond_path, "condition.ply")
        blender_img_path = os.path.join(cond_path, "blender_guidance.png")
        good_looking_render(mesh_path, blender_img_path, render_guidance=True,
                            shade_smooth=False, subdivide=False)


def infer(args, device):
    """Sample a guidance-conditioned Shap-E shape for a prompt and render it.

    Loads the ``transmitter`` and the ``text300M`` model and adds its control
    layers. It restores weights from ``args.model_path`` and obtains a
    guidance latent, either by encoding ``args.input_guidance_object_path``
    or by loading ``args.data_path``. It then calls ``test_model`` to write
    results into ``<output_dir>/<prompt filename>``.

    Optionally (``--render_guidance``) the guidance latent is also decoded
    into ``<output_dir>/condition``. With ``--render_blender``, the resulting
    meshes are rendered with Blender.

    Args:
        args: Namespace produced by the module-level ``parser``.
        device: ``torch.device`` on which models and tensors are placed.

    Raises:
        ValueError: if ``args.render_mode`` is not ``"stf"`` or ``"nerf"``, or
            if ``--encode_guidance`` is set without
            ``--input_guidance_object_path``.
    """
    render_mode = args.render_mode
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if render_mode not in ("stf", "nerf"):
        raise ValueError(f"render_mode must be 'stf' or 'nerf', got {render_mode!r}")
    if args.encode_guidance and args.input_guidance_object_path is None:
        raise ValueError("--encode_guidance requires --input_guidance_object_path")
    output_dir = args.output_dir
    output_resolution = args.output_resolution

    cameras = create_pan_cameras(output_resolution, device)
    xm = load_model('transmitter', device=device)
    model = _load_guided_model(args.model_path, device)
    diffusion = diffusion_from_config(load_config('diffusion'))
    # create output dir if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)
    guidance_shape = _get_guidance_latent(args, xm, device)

    with torch.no_grad():
        filename = prompt2filename(args.prompt)
        # NOTE(review): the prompt fed to the model is rebuilt from the
        # sanitized filename. Characters dropped by prompt2filename, and any
        # underscores in the original prompt (turned into spaces), therefore
        # differ from the user's input — confirm this is intended.
        prompt = filename.replace("_", " ")
        output_path = os.path.join(output_dir, filename)
        os.makedirs(output_path, exist_ok=True)
        # Rendering Model Output
        print(f"rendering samples for prompt: {prompt}")
        test_model(model=model,
                   diffusion=diffusion,
                   xm=xm,
                   output_folder=output_path,
                   cond=guidance_shape[0].to(device).detach(),
                   epoch=0,
                   prompt=prompt,
                   device=device,
                   guidance_scale=args.guidance_scale,
                   render_mode=render_mode,
                   size=output_resolution,
                   save_mesh=True)
        if args.render_blender:
            # Path layout assumes test_model wrote output/output.ply — TODO confirm.
            mesh_path = os.path.join(output_path, "output/output.ply")
            blender_img_path = os.path.join(output_path, "output/blender_output.png")
            good_looking_render(mesh_path, blender_img_path, plastic=False)
        if args.render_guidance:
            # Rendering Guidance
            print(f"rendering condition latent for prompt: {prompt}")
            _render_guidance(xm, guidance_shape, cameras, output_dir, render_mode,
                             output_resolution, args.render_blender)
    torch.cuda.empty_cache()
if __name__ == '__main__':
    # Parse the CLI first, then run on the GPU when torch can see one.
    args = parser.parse_args()
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    infer(args, device)