Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjust config and naming #42

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion face_swapper/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ This `config.ini` utilizes the MegaFace dataset to train the Face Swapper model.
[training.dataset]
file_pattern = .datasets/vggface2/**/*.jpg
warp_template = vgg_face_hq_to_arcface_128_v2
transform_size = 256
batch_mode = equal
batch_ratio = 0.2
resolution = 256
```

```
Expand Down Expand Up @@ -72,6 +72,7 @@ attribute_weight = 10
reconstruction_weight = 20
identity_weight = 20
gaze_weight = 0
gaze_scale_factor = 1
pose_weight = 0
expression_weight = 0
```
Expand Down
3 changes: 2 additions & 1 deletion face_swapper/config.ini
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[training.dataset]
file_pattern =
warp_template =
transform_size =
batch_mode =
batch_ratio =

Expand All @@ -26,14 +27,14 @@ num_filters =
num_layers =
num_discriminators =
kernel_size =
resolution =

[training.losses]
adversarial_weight =
attribute_weight =
reconstruction_weight =
identity_weight =
gaze_weight =
gaze_scale_factor =
pose_weight =
expression_weight =

Expand Down
6 changes: 3 additions & 3 deletions face_swapper/src/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@


class DynamicDataset(Dataset[Tensor]):
def __init__(self, file_pattern : str, warp_template : WarpTemplate, batch_mode : BatchMode, batch_ratio : float, resolution : int) -> None:
def __init__(self, file_pattern : str, warp_template : WarpTemplate, transform_size : int, batch_mode : BatchMode, batch_ratio : float) -> None:
self.file_paths = glob.glob(file_pattern)
self.warp_template = warp_template
self.transform_size = transform_size
self.batch_mode = batch_mode
self.batch_ratio = batch_ratio
self.resolution = resolution
self.transforms = self.compose_transforms()

def __getitem__(self, index : int) -> Batch:
Expand All @@ -39,7 +39,7 @@ def compose_transforms(self) -> transforms:
[
AugmentTransform(),
transforms.ToPILImage(),
transforms.Resize((self.resolution, self.resolution), interpolation = transforms.InterpolationMode.BICUBIC),
transforms.Resize((self.transform_size, self.transform_size), interpolation = transforms.InterpolationMode.BICUBIC),
transforms.ToTensor(),
WarpTransform(self.warp_template),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
Expand Down
2 changes: 1 addition & 1 deletion face_swapper/src/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def warp_tensor(input_tensor : Tensor, warp_template : WarpTemplate) -> Tensor:

def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding:
crop_tensor = warp_tensor(input_tensor, 'arcface_128_v2_to_arcface_112_v2')
crop_tensor = nn.functional.interpolate(crop_tensor, size = (112, 112), mode = 'area')
crop_tensor = nn.functional.interpolate(crop_tensor, size = 112, mode = 'area')
crop_tensor[:, :, :padding[0], :] = 0
crop_tensor[:, :, 112 - padding[1]:, :] = 0
crop_tensor[:, :, :, :padding[2]] = 0
Expand Down
8 changes: 4 additions & 4 deletions face_swapper/src/models/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,15 +169,15 @@ def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tenso
return gaze_loss, weighted_gaze_loss

def detect_gaze(self, input_tensor : Tensor) -> Gaze:
resolution = CONFIG.getint('training.dataset', 'resolution')
scale_factor = resolution / 256
scale_factor = CONFIG.getint('training.losses', 'gaze_scale_factor')
y_min = int(60 * scale_factor)
y_max = int(224 * scale_factor)
x_min = int(16 * scale_factor)
x_max = int(205 * scale_factor)
crop_tensor = input_tensor[:, :, y_min: y_max, x_min: x_max]

crop_tensor = input_tensor[:, :, y_min:y_max, x_min:x_max]
crop_tensor = (crop_tensor + 1) * 0.5
crop_tensor = transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ])(crop_tensor)
crop_tensor = nn.functional.interpolate(crop_tensor, size = (448, 448), mode = 'bicubic')
crop_tensor = nn.functional.interpolate(crop_tensor, size = 448, mode = 'bicubic')
pitch_tensor, yaw_tensor = self.gazer(crop_tensor)
return pitch_tensor, yaw_tensor
9 changes: 4 additions & 5 deletions face_swapper/src/networks/aad.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ def forward(self, source_embedding : Embedding, target_attributes : Attributes)
temp_tensors = self.pixel_shuffle_up_sample(source_embedding)

for index, layer in enumerate(self.layers[:-1]):
temp_shape = target_attributes[index + 1].shape[2:]
temp_tensor = layer(temp_tensors, target_attributes[index], source_embedding)
temp_tensors = nn.functional.interpolate(temp_tensor, temp_shape, mode = 'bilinear', align_corners = False)
temp_size = target_attributes[index + 1].shape[2:]
temp_tensors = nn.functional.interpolate(temp_tensor, temp_size, mode = 'bilinear', align_corners = False)

temp_tensors = self.layers[-1](temp_tensors, target_attributes[-1], source_embedding)
output_tensor = torch.tanh(temp_tensors)
Expand Down Expand Up @@ -113,10 +113,9 @@ def __init__(self, input_channels : int, attribute_channels : int, identity_chan

def forward(self, input_tensor : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor:
temp_tensor = self.instance_norm(input_tensor)
temp_size = temp_tensor.shape[2:]

if attribute_embedding.shape[2:] != temp_tensor.shape[2:]:
attribute_embedding = nn.functional.interpolate(attribute_embedding, size = temp_tensor.shape[2:], mode = 'bilinear')

attribute_embedding = nn.functional.interpolate(attribute_embedding, size = temp_size, mode = 'bilinear')
Copy link
Contributor Author

@henryruhs henryruhs Mar 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This enforces the same shape in every case — could something be wrong with that approach?

Is this even needed? It feels like a hack.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

first two layers need this

attribute_scale = self.conv1(attribute_embedding)
attribute_shift = self.conv2(attribute_embedding)
attribute_modulation = attribute_scale * temp_tensor + attribute_shift
Expand Down
4 changes: 2 additions & 2 deletions face_swapper/src/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,15 +198,15 @@ def create_trainer() -> Trainer:
def train() -> None:
dataset_file_pattern = CONFIG.get('training.dataset', 'file_pattern')
dataset_warp_template = cast(WarpTemplate, CONFIG.get('training.dataset', 'warp_template'))
dataset_transform_size = CONFIG.getint('training.dataset', 'transform_size')
dataset_batch_mode = cast(BatchMode, CONFIG.get('training.dataset', 'batch_mode'))
dataset_batch_ratio = CONFIG.getfloat('training.dataset', 'batch_ratio')
dataset_resolution = CONFIG.getint('training.dataset', 'resolution')
output_resume_path = CONFIG.get('training.output', 'resume_path')

if torch.cuda.is_available():
torch.set_float32_matmul_precision('high')

dataset = DynamicDataset(dataset_file_pattern, dataset_warp_template, dataset_batch_mode, dataset_batch_ratio, dataset_resolution)
dataset = DynamicDataset(dataset_file_pattern, dataset_warp_template, dataset_transform_size, dataset_batch_mode, dataset_batch_ratio)
training_loader, validation_loader = create_loaders(dataset)
face_swapper_trainer = FaceSwapperTrainer()
trainer = create_trainer()
Expand Down