diff --git a/face_swapper/README.md b/face_swapper/README.md
index 2395ef2..0e5e413 100644
--- a/face_swapper/README.md
+++ b/face_swapper/README.md
@@ -29,9 +29,9 @@ This `config.ini` utilizes the MegaFace dataset to train the Face Swapper model.
 [training.dataset]
 file_pattern = .datasets/vggface2/**/*.jpg
 warp_template = vgg_face_hq_to_arcface_128_v2
+transform_size = 256
 batch_mode = equal
 batch_ratio = 0.2
-resolution = 256
 ```
 
 ```
@@ -72,6 +72,7 @@ attribute_weight = 10
 reconstruction_weight = 20
 identity_weight = 20
 gaze_weight = 0
+gaze_scale_factor = 1
 pose_weight = 0
 expression_weight = 0
 ```
diff --git a/face_swapper/config.ini b/face_swapper/config.ini
index eb2cbe2..052fbe0 100644
--- a/face_swapper/config.ini
+++ b/face_swapper/config.ini
@@ -1,6 +1,7 @@
 [training.dataset]
 file_pattern =
 warp_template =
+transform_size =
 batch_mode =
 batch_ratio =
 
@@ -26,7 +27,6 @@ num_filters =
 num_layers =
 num_discriminators =
 kernel_size =
-resolution =
 
 [training.losses]
 adversarial_weight =
@@ -34,6 +34,7 @@ attribute_weight =
 reconstruction_weight =
 identity_weight =
 gaze_weight =
+gaze_scale_factor =
 pose_weight =
 expression_weight =
 
diff --git a/face_swapper/src/dataset.py b/face_swapper/src/dataset.py
index ce45a90..393a934 100644
--- a/face_swapper/src/dataset.py
+++ b/face_swapper/src/dataset.py
@@ -12,12 +12,12 @@ class DynamicDataset(Dataset[Tensor]):
-	def __init__(self, file_pattern : str, warp_template : WarpTemplate, batch_mode : BatchMode, batch_ratio : float, resolution : int) -> None:
+	def __init__(self, file_pattern : str, warp_template : WarpTemplate, transform_size : int, batch_mode : BatchMode, batch_ratio : float) -> None:
 		self.file_paths = glob.glob(file_pattern)
 		self.warp_template = warp_template
+		self.transform_size = transform_size
 		self.batch_mode = batch_mode
 		self.batch_ratio = batch_ratio
-		self.resolution = resolution
 		self.transforms = self.compose_transforms()
 
 	def __getitem__(self, index : int) -> Batch:
@@ -39,7 +39,7 @@ def compose_transforms(self) -> transforms:
 		[
 			AugmentTransform(),
 			transforms.ToPILImage(),
-			transforms.Resize((self.resolution, self.resolution), interpolation = transforms.InterpolationMode.BICUBIC),
+			transforms.Resize((self.transform_size, self.transform_size), interpolation = transforms.InterpolationMode.BICUBIC),
 			transforms.ToTensor(),
 			WarpTransform(self.warp_template),
 			transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
diff --git a/face_swapper/src/helper.py b/face_swapper/src/helper.py
index 557fcbe..97d4cfc 100644
--- a/face_swapper/src/helper.py
+++ b/face_swapper/src/helper.py
@@ -27,7 +27,7 @@ def warp_tensor(input_tensor : Tensor, warp_template : WarpTemplate) -> Tensor:
 
 def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding:
 	crop_tensor = warp_tensor(input_tensor, 'arcface_128_v2_to_arcface_112_v2')
-	crop_tensor = nn.functional.interpolate(crop_tensor, size = (112, 112), mode = 'area')
+	crop_tensor = nn.functional.interpolate(crop_tensor, size = 112, mode = 'area')
 	crop_tensor[:, :, :padding[0], :] = 0
 	crop_tensor[:, :, 112 - padding[1]:, :] = 0
 	crop_tensor[:, :, :, :padding[2]] = 0
diff --git a/face_swapper/src/models/loss.py b/face_swapper/src/models/loss.py
index 9792bf4..d2b6f40 100644
--- a/face_swapper/src/models/loss.py
+++ b/face_swapper/src/models/loss.py
@@ -169,15 +169,15 @@ def forward(self, target_tensor : Tensor, output_tensor : Tensor) -> Tuple[Tenso
 		return gaze_loss, weighted_gaze_loss
 
 	def detect_gaze(self, input_tensor : Tensor) -> Gaze:
-		resolution = CONFIG.getint('training.dataset', 'resolution')
-		scale_factor = resolution / 256
+		scale_factor = CONFIG.getint('training.losses', 'gaze_scale_factor')
 		y_min = int(60 * scale_factor)
 		y_max = int(224 * scale_factor)
 		x_min = int(16 * scale_factor)
 		x_max = int(205 * scale_factor)
-		crop_tensor = input_tensor[:, :, y_min: y_max, x_min: x_max]
+
+		crop_tensor = input_tensor[:, :, y_min:y_max, x_min:x_max]
 		crop_tensor = (crop_tensor + 1) * 0.5
 		crop_tensor = transforms.Normalize(mean = [ 0.485, 0.456, 0.406 ], std = [ 0.229, 0.224, 0.225 ])(crop_tensor)
-		crop_tensor = nn.functional.interpolate(crop_tensor, size = (448, 448), mode = 'bicubic')
+		crop_tensor = nn.functional.interpolate(crop_tensor, size = 448, mode = 'bicubic')
 		pitch_tensor, yaw_tensor = self.gazer(crop_tensor)
 		return pitch_tensor, yaw_tensor
diff --git a/face_swapper/src/networks/aad.py b/face_swapper/src/networks/aad.py
index 7db1761..a06a3c4 100644
--- a/face_swapper/src/networks/aad.py
+++ b/face_swapper/src/networks/aad.py
@@ -28,9 +28,9 @@ def forward(self, source_embedding : Embedding, target_attributes : Attributes)
 		temp_tensors = self.pixel_shuffle_up_sample(source_embedding)
 
 		for index, layer in enumerate(self.layers[:-1]):
-			temp_shape = target_attributes[index + 1].shape[2:]
 			temp_tensor = layer(temp_tensors, target_attributes[index], source_embedding)
-			temp_tensors = nn.functional.interpolate(temp_tensor, temp_shape, mode = 'bilinear', align_corners = False)
+			temp_size = target_attributes[index + 1].shape[2:]
+			temp_tensors = nn.functional.interpolate(temp_tensor, temp_size, mode = 'bilinear', align_corners = False)
 
 		temp_tensors = self.layers[-1](temp_tensors, target_attributes[-1], source_embedding)
 		output_tensor = torch.tanh(temp_tensors)
@@ -113,10 +113,9 @@ def __init__(self, input_channels : int, attribute_channels : int, identity_chan
 
 	def forward(self, input_tensor : Tensor, attribute_embedding : Embedding, identity_embedding : Embedding) -> Tensor:
 		temp_tensor = self.instance_norm(input_tensor)
+		temp_size = temp_tensor.shape[2:]
 
-		if attribute_embedding.shape[2:] != temp_tensor.shape[2:]:
-			attribute_embedding = nn.functional.interpolate(attribute_embedding, size = temp_tensor.shape[2:], mode = 'bilinear')
-
+		attribute_embedding = nn.functional.interpolate(attribute_embedding, size = temp_size, mode = 'bilinear')
 		attribute_scale = self.conv1(attribute_embedding)
 		attribute_shift = self.conv2(attribute_embedding)
 		attribute_modulation = attribute_scale * temp_tensor + attribute_shift
diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py
index 821c492..d5e0245 100644
--- a/face_swapper/src/training.py
+++ b/face_swapper/src/training.py
@@ -198,15 +198,15 @@ def create_trainer() -> Trainer:
 def train() -> None:
 	dataset_file_pattern = CONFIG.get('training.dataset', 'file_pattern')
 	dataset_warp_template = cast(WarpTemplate, CONFIG.get('training.dataset', 'warp_template'))
+	dataset_transform_size = CONFIG.getint('training.dataset', 'transform_size')
 	dataset_batch_mode = cast(BatchMode, CONFIG.get('training.dataset', 'batch_mode'))
 	dataset_batch_ratio = CONFIG.getfloat('training.dataset', 'batch_ratio')
-	dataset_resolution = CONFIG.getint('training.dataset', 'resolution')
 	output_resume_path = CONFIG.get('training.output', 'resume_path')
 
 	if torch.cuda.is_available():
 		torch.set_float32_matmul_precision('high')
 
-	dataset = DynamicDataset(dataset_file_pattern, dataset_warp_template, dataset_batch_mode, dataset_batch_ratio, dataset_resolution)
+	dataset = DynamicDataset(dataset_file_pattern, dataset_warp_template, dataset_transform_size, dataset_batch_mode, dataset_batch_ratio)
 	training_loader, validation_loader = create_loaders(dataset)
 	face_swapper_trainer = FaceSwapperTrainer()
 	trainer = create_trainer()
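
A quick way to sanity-check the renamed options is to read them back with plain `configparser`, which mirrors how `CONFIG.getint` is used in `training.py` and `loss.py`. This is a minimal sketch, not part of the diff: the config path and values are taken from the README excerpt above, and note that `getint` restricts `gaze_scale_factor` to whole numbers, whereas the old `resolution / 256` expression could yield a float.

```python
from configparser import ConfigParser

# Illustrative sketch: assumes a config.ini populated as in the README excerpt.
config = ConfigParser()
config.read('face_swapper/config.ini')

# transform_size replaces the old [training.dataset] resolution key.
transform_size = config.getint('training.dataset', 'transform_size')  # 256

# gaze_scale_factor replaces the derived resolution / 256; with a 256px
# dataset both evaluate to 1, so the README default keeps the old behaviour.
gaze_scale_factor = config.getint('training.losses', 'gaze_scale_factor')  # 1

# The gaze crop box from detect_gaze, scaled the same way loss.py now does it.
y_min, y_max = int(60 * gaze_scale_factor), int(224 * gaze_scale_factor)
x_min, x_max = int(16 * gaze_scale_factor), int(205 * gaze_scale_factor)
print((y_min, y_max), (x_min, x_max))  # (60, 224) (16, 205) for the default
```

Note also that `transform_size` sits in third position in `DynamicDataset.__init__`, so any out-of-tree callers passing positional arguments need the new `(file_pattern, warp_template, transform_size, batch_mode, batch_ratio)` order.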