# app.py (forked from modelscope/facechain)
# Copyright (c) Alibaba, Inc. and its affiliates.
import enum
import os
import shutil
import time
from concurrent.futures import ThreadPoolExecutor
from torch import multiprocessing
import cv2
import gradio as gr
import numpy as np
import torch
from modelscope import snapshot_download
from facechain.inference import GenPortrait
from facechain.train_text_to_image_lora import prepare_dataset, data_process_fn
from facechain.constants import neg_prompt, pos_prompt_with_cloth, pos_prompt_with_style, styles, cloth_prompt, pose_models, pose_examples
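
# One worker serializes LoRA training jobs; up to five inference jobs may run concurrently.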
training_threadpool = ThreadPoolExecutor(max_workers=1)
inference_threadpool = ThreadPoolExecutor(max_workers=5)

training_done_count = 0
inference_done_count = 0

HOT_MODELS = [
    "\N{fire}数字身份(Digital Identity)",
]


class UploadTarget(enum.Enum):
    PERSONAL_PROFILE = 'Personal Profile'
    LORA_LIBRARY = 'LoRA Library'
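

# Refresh the clothing choices and the example prompt when the selected style model changes.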
def update_cloth(style_index):
    prompts = []
    if style_index == 0:
        example_prompt = generate_pos_prompt(styles[style_index]['name'],
                                             cloth_prompt[0]['prompt'])
        for prompt in cloth_prompt:
            prompts.append(prompt['name'])
    else:
        example_prompt = generate_pos_prompt(styles[style_index]['name'],
                                             styles[style_index]['add_prompt_style'])
        prompts.append(styles[style_index]['cloth_name'])
    return gr.Radio.update(choices=prompts, value=prompts[0]), gr.Textbox.update(value=example_prompt)
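

# Regenerate the positive prompt when the clothing style selection changes.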
def update_prompt(style_index, cloth_index):
    if style_index == 0:
        pos_prompt = generate_pos_prompt(styles[style_index]['name'],
                                         cloth_prompt[cloth_index]['prompt'])
    else:
        pos_prompt = generate_pos_prompt(styles[style_index]['name'],
                                         styles[style_index]['add_prompt_style'])
    return gr.Textbox.update(value=pos_prompt)
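

# Switch the pose control model depending on whether a pose image has been uploaded.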
def update_pose_model(pose_image):
    if pose_image is None:
        return gr.Radio.update(value=pose_models[0]['name'])
    else:
        return gr.Radio.update(value=pose_models[1]['name'])
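

# Stitch the generated images side by side into a single preview image.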
def concatenate_images(images):
    heights = [img.shape[0] for img in images]
    max_width = sum([img.shape[1] for img in images])
    concatenated_image = np.zeros((max(heights), max_width, 3), dtype=np.uint8)
    x_offset = 0
    for img in images:
        concatenated_image[0:img.shape[0], x_offset:x_offset + img.shape[1], :] = img
        x_offset += img.shape[1]
    return concatenated_image
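

# Kick off LoRA fine-tuning in a subprocess via `accelerate launch`.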
def train_lora_fn(foundation_model_path=None, revision=None, output_img_dir=None, work_dir=None):
    os.system(
        f'PYTHONPATH=. accelerate launch facechain/train_text_to_image_lora.py --pretrained_model_name_or_path={foundation_model_path} '
        f'--revision={revision} --sub_path="film/film" '
        f'--output_dataset_name={output_img_dir} --caption_column="text" --resolution=512 '
        f'--random_flip --train_batch_size=1 --num_train_epochs=200 --checkpointing_steps=5000 '
        f'--learning_rate=1.5e-04 --lr_scheduler="cosine" --lr_warmup_steps=0 --seed=42 --output_dir={work_dir} '
        f'--lora_r=4 --lora_alpha=32 --lora_text_encoder_r=32 --lora_text_encoder_alpha=32 --resume_from_checkpoint="fromfacecommon"')
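

# Compose the positive prompt from the selected style model, or from the default cloth prompt.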
def generate_pos_prompt(style_model, prompt_cloth):
    if style_model == styles[0]['name'] or style_model is None:
        pos_prompt = pos_prompt_with_cloth.format(prompt_cloth)
    else:
        matched = list(filter(lambda style: style_model == style['name'], styles))
        if len(matched) == 0:
            raise ValueError(f'styles not found: {style_model}')
        matched = matched[0]
        pos_prompt = pos_prompt_with_style.format(matched['add_prompt_style'])
    return pos_prompt
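

# Submit a portrait-generation job to the inference pool and stream progress messages back to the UI.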
def launch_pipeline(uuid,
                    pos_prompt,
                    user_models,
                    num_images=1,
                    style_model=None,
                    multiplier_style=0.25,
                    pose_model=None,
                    pose_image=None
                    ):
    base_model = 'ly261666/cv_portrait_model'
    before_queue_size = inference_threadpool._work_queue.qsize()
    before_done_count = inference_done_count
    style_model = styles[style_model]['name']

    if style_model == styles[0]['name']:
        style_model_path = None
    else:
        matched = list(filter(lambda style: style_model == style['name'], styles))
        if len(matched) == 0:
            raise ValueError(f'styles not found: {style_model}')
        matched = matched[0]
        model_dir = snapshot_download(matched['model_id'], revision=matched['revision'])
        style_model_path = os.path.join(model_dir, matched['bin_file'])

    if pose_image is None:
        pose_model_path = None
        use_depth_control = False
    else:
        print(pose_model)
        pose_model_path = pose_models[pose_model]['pth']
        if pose_model == 1:
            use_depth_control = True
        else:
            use_depth_control = False

    print("-------user_models: ", user_models)
    if not uuid:
        if os.getenv("MODELSCOPE_ENVIRONMENT") == 'studio':
            # launch_pipeline is a generator, so yield the message instead of returning it
            yield ["请登陆后使用! (Please login first)", None]
            return
        else:
            uuid = 'qw'

    use_main_model = True
    use_face_swap = True
    use_post_process = True
    use_stylization = False

    output_model_name = 'personalization_lora'
    instance_data_dir = os.path.join('/tmp', uuid, 'training_data', output_model_name)
    lora_model_path = f'/tmp/{uuid}/{output_model_name}'

    gen_portrait = GenPortrait(pose_model_path, pose_image, use_depth_control, pos_prompt, neg_prompt, style_model_path, multiplier_style, use_main_model,
                               use_face_swap, use_post_process,
                               use_stylization)
    num_images = min(6, num_images)

    future = inference_threadpool.submit(gen_portrait, instance_data_dir,
                                         num_images, base_model, lora_model_path, 'film/film', 'v2.0')

    while not future.done():
        is_processing = future.running()
        if not is_processing:
            cur_done_count = inference_done_count
            to_wait = before_queue_size - (cur_done_count - before_done_count)
            yield ["排队等待资源中,前方还有{}个生成任务, 预计需要等待{}分钟...".format(to_wait, to_wait * 2.5), None]
        else:
            yield ["生成中, 请耐心等待(Generating)...", None]
        time.sleep(1)

    outputs = future.result()
    outputs_RGB = []
    for out_tmp in outputs:
        outputs_RGB.append(cv2.cvtColor(out_tmp, cv2.COLOR_BGR2RGB))
    image_path = './lora_result.png'
    if len(outputs) > 0:
        result = concatenate_images(outputs)
        cv2.imwrite(image_path, result)
        yield ["生成完毕(Generating done)!", outputs_RGB]
    else:
        yield ["生成失败,请重试(Generating failed, please retry)!", outputs_RGB]
class Trainer:
    def __init__(self):
        pass

    def run(
            self,
            uuid: str,
            instance_images: list,
    ) -> str:
        if not torch.cuda.is_available():
            raise gr.Error('CUDA is not available.')

        if instance_images is None:
            raise gr.Error('您需要上传训练图片(Please upload photos)!')
        if len(instance_images) > 10:
            raise gr.Error('您需要上传不多于10张训练图片(Please upload at most 10 photos)!')
        if not uuid:
            if os.getenv("MODELSCOPE_ENVIRONMENT") == 'studio':
                return "请登陆后使用(Please login first)! "
            else:
                uuid = 'qw'

        output_model_name = 'personalization_lora'

        # move the user-uploaded photos to the target training directory
        instance_data_dir = os.path.join('/tmp', uuid, 'training_data', output_model_name)
        print("--------uuid: ", uuid)

        if not os.path.exists(f"/tmp/{uuid}"):
            os.makedirs(f"/tmp/{uuid}")
        work_dir = f"/tmp/{uuid}/{output_model_name}"
        print("----------work_dir: ", work_dir)
        shutil.rmtree(work_dir, ignore_errors=True)
        shutil.rmtree(instance_data_dir, ignore_errors=True)
        prepare_dataset([img['name'] for img in instance_images], output_dataset_dir=instance_data_dir)
        data_process_fn(instance_data_dir, True)

        # train lora
        train_lora_fn(foundation_model_path='ly261666/cv_portrait_model',
                      revision='v2.0',
                      output_img_dir=instance_data_dir,
                      work_dir=work_dir)

        message = '训练已经完成!请切换至 [形象体验] 标签体验模型效果(Training done, please switch to the inference tab to generate photos.)'
        print(message)
        return message
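

# List trained LoRA models found under the user's /tmp workspace and refresh the model selector.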
def flash_model_list(uuid):
    folder_path = f"/tmp/{uuid}"
    folder_list = []
    print("------flash_model_list folder_path: ", folder_path)
    if not os.path.exists(folder_path):
        print('--------The folder_path is missing.')
    else:
        files = os.listdir(folder_path)
        for file in files:
            file_path = os.path.join(folder_path, file)
            if os.path.isdir(file_path):
                file_lora_path = f"{file_path}/output/pytorch_lora_weights.bin"
                if os.path.exists(file_lora_path):
                    folder_list.append(file)

    print("-------folder_list + HOT_MODELS: ", folder_list + HOT_MODELS)
    return gr.Radio.update(choices=HOT_MODELS + folder_list)
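

# Collect newly uploaded files together with the files already shown in the gallery.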
def upload_file(files, current_files):
    file_paths = [file_d['name'] for file_d in current_files] + [file.name for file in files]
    return file_paths
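

# Build the training tab: photo upload gallery, usage instructions, and the training trigger.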
def train_input():
    trainer = Trainer()

    with gr.Blocks() as demo:
        uuid = gr.Text(label="modelscope_uuid", visible=False)
        with gr.Row():
            with gr.Column():
                with gr.Box():
                    gr.Markdown('训练图片(Training photos)')
                    instance_images = gr.Gallery()
                    upload_button = gr.UploadButton("选择图片上传(Upload photos)", file_types=["image"],
                                                    file_count="multiple")
                    clear_button = gr.Button("清空图片(Clear photos)")
                    clear_button.click(fn=lambda: [], inputs=None, outputs=instance_images)
                    upload_button.upload(upload_file, inputs=[upload_button, instance_images], outputs=instance_images, queue=False)
                    gr.Markdown('''
                    - Step 1. 上传计划训练的图片,3~10张头肩照(注意:请避免图片中出现多人脸、脸部遮挡等情况,否则可能导致效果异常)
                    - Step 2. 点击 [开始训练] ,启动形象定制化训练,约需15分钟,请耐心等待~
                    - Step 3. 切换至 [形象体验] ,生成你的风格照片
                    ''')
                    gr.Markdown('''
                    - Step 1. Upload 3-10 headshot photos of yours (Note: avoid photos with multiple faces or face obstruction, which may lead to non-ideal results).
                    - Step 2. Click [Train] to start training for customizing your Digital-Twin; this may take up to 15 minutes.
                    - Step 3. Switch to the [Inference] tab to generate stylized photos.
                    ''')

        run_button = gr.Button('开始训练(等待上传图片加载显示出来再点,否则会报错)'
                               'Start training (please wait until photo(s) fully uploaded, otherwise it may result in training failure)')

        with gr.Box():
            gr.Markdown('''
            请等待训练完成
            Please wait for the training to complete.
            ''')
            output_message = gr.Markdown()
        with gr.Box():
            gr.Markdown('''
            碰到抓狂的错误或者计算资源紧张的情况下,推荐直接在[NoteBook](https://modelscope.cn/my/mynotebook/preset)上进行体验。
            安装方法请参考:https://github.com/modelscope/facechain .
            If you are experiencing prolonged waiting time, you may try on [ModelScope NoteBook](https://modelscope.cn/my/mynotebook/preset) to prepare your dedicated environment.
            You may refer to: https://github.com/modelscope/facechain for installation instructions.
            ''')

        run_button.click(fn=trainer.run,
                         inputs=[
                             uuid,
                             instance_images,
                         ],
                         outputs=[output_message])

    return demo
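

# Build the inference tab: model/style/cloth selection, expert options, and the generation trigger.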
def inference_input():
    with gr.Blocks() as demo:
        uuid = gr.Text(label="modelscope_uuid", visible=False)
        with gr.Row():
            with gr.Column():
                user_models = gr.Radio(label="模型选择(Model list)", choices=HOT_MODELS, type="value",
                                       value=HOT_MODELS[0])
                style_model_list = []
                for style in styles:
                    style_model_list.append(style['name'])
                style_model = gr.Dropdown(choices=style_model_list, value=styles[0]['name'],
                                          type="index", label="风格模型(Style model)")

                prompts = []
                for prompt in cloth_prompt:
                    prompts.append(prompt['name'])
                cloth_style = gr.Radio(choices=prompts, value=cloth_prompt[0]['name'],
                                       type="index", label="服装风格(Cloth style)")
                pmodels = []
                for pmodel in pose_models:
                    pmodels.append(pmodel['name'])

                with gr.Accordion("高级选项(Expert)", open=False):
                    pos_prompt = gr.Textbox(label="提示语(Prompt)", lines=3,
                                            value=generate_pos_prompt(None, cloth_prompt[0]['prompt']),
                                            interactive=True)
                    multiplier_style = gr.Slider(minimum=0, maximum=1, value=0.25,
                                                 step=0.05, label='风格权重(Multiplier style)')
                    pose_image = gr.Image(source='upload', type='filepath', label='姿态图片(Pose image)')
                    gr.Examples(pose_examples['man'], inputs=[pose_image], label='男性姿态示例(Male pose examples)')
                    gr.Examples(pose_examples['woman'], inputs=[pose_image], label='女性姿态示例(Female pose examples)')
                    pose_model = gr.Radio(choices=pmodels, value=pose_models[0]['name'],
                                          type="index", label="姿态控制模型(Pose control model)")
                with gr.Box():
                    num_images = gr.Number(
                        label='生成图片数量(Number of photos)', value=6, precision=1, minimum=1, maximum=6)
                    gr.Markdown('''
                    注意:最多支持生成6张图片!(You may generate a maximum of 6 photos at one time!)
                    ''')

        display_button = gr.Button('开始生成(Start!)')

        with gr.Box():
            infer_progress = gr.Textbox(label="生成进度(Progress)", value="当前无生成任务(No task)", interactive=False)
        with gr.Box():
            gr.Markdown('生成结果(Result)')
            output_images = gr.Gallery(label='Output', show_label=False).style(columns=3, rows=2, height=600,
                                                                               object_fit="contain")

        style_model.change(update_cloth, style_model, [cloth_style, pos_prompt])
        cloth_style.change(update_prompt, [style_model, cloth_style], [pos_prompt])
        pose_image.change(update_pose_model, pose_image, [pose_model])
        display_button.click(fn=launch_pipeline,
                             inputs=[uuid, pos_prompt, user_models, num_images, style_model, multiplier_style,
                                     pose_model, pose_image],
                             outputs=[infer_progress, output_images])

    return demo
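

# Assemble both tabs into the top-level Gradio app.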
with gr.Blocks(css='style.css') as demo:
    with gr.Tabs():
        with gr.TabItem('\N{rocket}形象定制(Train)'):
            train_input()
        with gr.TabItem('\N{party popper}形象体验(Inference)'):
            inference_input()
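
# Use the 'spawn' start method for torch multiprocessing, then launch the queued Gradio app.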
if __name__ == "__main__":
    multiprocessing.set_start_method('spawn')
    demo.queue(status_update_rate=1).launch(share=True)