forked from FuxiVirtualHuman/styletalk
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathface_model.py
127 lines (105 loc) · 4.21 KB
/
face_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import functools
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import generators.flow_util as flow_util
from generators.base_function import LayerNorm2d, ADAINHourglass, FineEncoder, FineDecoder
class FaceGenerator(nn.Module):
    """Generator composed of a mapping, a warping and an editing network.

    ``MappingNet`` turns ``driving_source`` into a descriptor, ``WarpingNet``
    warps ``input_image`` conditioned on that descriptor, and (unless only the
    warping stage is requested) ``EditingNet`` produces the final image from
    the input image, the warped image and the descriptor.

    Args:
        mapping_net: Keyword arguments forwarded to ``MappingNet``.
        warpping_net: Keyword arguments forwarded to ``WarpingNet``
            (merged with ``common``).
        editing_net: Keyword arguments forwarded to ``EditingNet``
            (merged with ``common``).
        common: Keyword arguments shared by ``WarpingNet`` and ``EditingNet``.
            A key present both here and in a per-network dict raises
            ``TypeError`` (duplicate keyword argument).
    """

    def __init__(self, mapping_net, warpping_net, editing_net, common):
        super().__init__()
        # NOTE: the misspelled attribute name ``warpping_net`` is intentional to
        # keep: submodule attribute names become state_dict keys, so renaming
        # it would break loading existing checkpoints.
        self.mapping_net = MappingNet(**mapping_net)
        self.warpping_net = WarpingNet(**warpping_net, **common)
        self.editing_net = EditingNet(**editing_net, **common)

    def forward(self, input_image, driving_source, stage=None):
        """Run the generator.

        Args:
            input_image: Image passed to both the warping and editing networks.
            driving_source: Input to ``MappingNet`` (its ``forward`` argument is
                named ``input_3dmm``; presumably 3DMM coefficients — confirm).
            stage: ``'warp'`` to stop after the warping stage; any other value
                (default ``None``) also runs the editing stage.

        Returns:
            The dict returned by the warping network (``WarpingNet`` fills
            ``'flow_field'`` and ``'warp_image'``), with ``'fake_image'`` added
            when the editing stage runs.
        """
        # Mapping and warping are needed in every mode, so compute them once
        # instead of duplicating them in each branch.
        descriptor = self.mapping_net(driving_source)
        output = self.warpping_net(input_image, descriptor)
        if stage == 'warp':
            return output
        output['fake_image'] = self.editing_net(
            input_image, output['warp_image'], descriptor
        )
        return output
class MappingNet(nn.Module):
    """Encode a channels-first sequence into a single descriptor per sample.

    The network is an unpadded kernel-7 ``Conv1d`` projection, then ``layer``
    residual blocks (``LeakyReLU(0.1)`` followed by an unpadded ``Conv1d`` with
    kernel 3 and dilation 3), then adaptive average pooling to length 1.

    Args:
        coeff_nc: Number of input channels.
        descriptor_nc: Number of channels in every hidden layer and the output.
        layer: Number of residual blocks (registered as ``encoder0`` ...).
    """

    def __init__(self, coeff_nc, descriptor_nc, layer):
        super().__init__()
        self.layer = layer
        act = nn.LeakyReLU(0.1)

        # Kept inside a Sequential so parameter names stay ``first.0.*``.
        self.first = nn.Sequential(
            torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True)
        )

        # Every dilated conv shrinks the last axis by dilation * (kernel - 1) = 6,
        # which is why forward() crops 3 samples from each end of the skip path.
        for idx in range(layer):
            block = nn.Sequential(
                act,
                torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3),
            )
            self.add_module(f"encoder{idx}", block)

        self.pooling = nn.AdaptiveAvgPool1d(1)
        self.output_nc = descriptor_nc

    def forward(self, input_3dmm):
        """Return the pooled descriptor for ``input_3dmm``.

        ``input_3dmm`` must be 3-D ``(batch, coeff_nc, length)`` (the residual
        crop indexes three axes) with ``length >= 7 + 6 * layer`` so every
        unpadded convolution yields at least one sample. The output has shape
        ``(batch, descriptor_nc, 1)``.
        """
        hidden = self.first(input_3dmm)
        for idx in range(self.layer):
            block = getattr(self, f"encoder{idx}")
            # Residual add: align the skip path with the shortened conv output.
            hidden = block(hidden) + hidden[:, :, 3:-3]
        return self.pooling(hidden)
class WarpingNet(nn.Module):
    """Predict a 2-channel flow field and warp the input image with it.

    An ``ADAINHourglass`` conditioned on ``descriptor`` processes the image;
    a norm / LeakyReLU / 7x7 conv head maps its output to 2 channels, which
    ``flow_util`` converts to a deformation used to warp ``input_image``.

    Args:
        image_nc: Channels of the input image.
        descriptor_nc: Channels of the conditioning descriptor.
        base_nc, max_nc, encoder_layer, decoder_layer: Forwarded to
            ``ADAINHourglass``.
        use_spect: Forwarded to ``ADAINHourglass`` (presumably toggles
            spectral normalization — confirm in ``base_function``).
    """

    def __init__(
        self,
        image_nc,
        descriptor_nc,
        base_nc,
        max_nc,
        encoder_layer,
        decoder_layer,
        use_spect,
    ):
        super().__init__()
        act = nn.LeakyReLU(0.1)
        make_norm = functools.partial(LayerNorm2d, affine=True)
        self.descriptor_nc = descriptor_nc

        self.hourglass = ADAINHourglass(
            image_nc,
            descriptor_nc,
            base_nc,
            max_nc,
            encoder_layer,
            decoder_layer,
            nonlinearity=act,
            use_spect=use_spect,
        )

        hidden_nc = self.hourglass.output_nc
        # 2 output channels; presumably one flow component per spatial axis.
        self.flow_out = nn.Sequential(
            make_norm(hidden_nc),
            act,
            nn.Conv2d(hidden_nc, 2, kernel_size=7, stride=1, padding=3),
        )

        # Not used by forward(); kept so any external access to it still works.
        self.pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, input_image, descriptor):
        """Return ``{'flow_field': ..., 'warp_image': ...}`` for ``input_image``."""
        flow = self.flow_out(self.hourglass(input_image, descriptor))
        warped = flow_util.warp_image(
            input_image, flow_util.convert_flow_to_deformation(flow)
        )
        return {'flow_field': flow, 'warp_image': warped}
class EditingNet(nn.Module):
    """Produce an image from the input image, a warped image and a descriptor.

    The two images are concatenated on dim 1 and fed to a ``FineEncoder``; a
    ``FineDecoder`` conditioned on ``descriptor`` turns the encoding into the
    output image.

    Args:
        image_nc: Channels of each input image (the encoder receives twice this).
        descriptor_nc: Channels of the conditioning descriptor.
        layer, base_nc, max_nc: Forwarded to both encoder and decoder.
        num_res_blocks: Forwarded to ``FineDecoder``.
        use_spect: Forwarded to both encoder and decoder (presumably toggles
            spectral normalization — confirm in ``base_function``).
    """

    def __init__(
        self,
        image_nc,
        descriptor_nc,
        layer,
        base_nc,
        max_nc,
        num_res_blocks,
        use_spect,
    ):
        super().__init__()
        # One options dict shared by encoder and decoder (same LeakyReLU object).
        shared = dict(
            norm_layer=functools.partial(LayerNorm2d, affine=True),
            nonlinearity=nn.LeakyReLU(0.1),
            use_spect=use_spect,
        )
        self.descriptor_nc = descriptor_nc

        # The encoder sees both images stacked along dim 1, hence 2 * image_nc.
        self.encoder = FineEncoder(2 * image_nc, base_nc, max_nc, layer, **shared)
        self.decoder = FineDecoder(
            image_nc, descriptor_nc, base_nc, max_nc, layer, num_res_blocks, **shared
        )

    def forward(self, input_image, warp_image, descriptor):
        """Encode the stacked images and decode them conditioned on ``descriptor``."""
        stacked = torch.cat((input_image, warp_image), dim=1)
        return self.decoder(self.encoder(stacked), descriptor)