Skip to content

Commit

Permalink
Fix some naming and types
Browse files Browse the repository at this point in the history
  • Loading branch information
henryruhs committed Mar 3, 2025
1 parent b1e3d79 commit 4dee56a
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 13 deletions.
40 changes: 28 additions & 12 deletions face_swapper/src/helper.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,48 @@
import torch
from torch import Tensor, nn

from .types import AlignmentMatrices, EmbedderModule, Embedding, Padding
from .types import WarpMatrixSet, EmbedderModule, Embedding, Padding

ALIGNMENT_MATRICES: AlignmentMatrices =\
WARP_MATRIX_SET : WarpMatrixSet =\
{
'__vgg_face_hq__to__arcface_128_v2__': torch.tensor(
'vgg_face_hq_to_arcface_128_v2': torch.tensor(
[
[
[ 1.01305414, -0.00140513, -0.00585911 ],
[ 0.00140513, 1.01305414, 0.11169602 ]
], dtype = torch.float32),
'__arcface_128_v2__to__arcface_112_v2__': torch.tensor(
1.01305414,
-0.00140513,
-0.00585911
],
[
[ 8.75000016e-01, -1.07193451e-08, 3.80446920e-10 ],
[ 1.07193451e-08, 8.75000016e-01, -1.25000007e-01 ]
], dtype = torch.float32)
0.00140513,
1.01305414,
0.11169602
]
], dtype = torch.float32),
'arcface_128_v2_to_arcface_112_v2': torch.tensor(
[
[
8.75000016e-01,
-1.07193451e-08,
3.80446920e-10
],
[
1.07193451e-08,
8.75000016e-01,
-1.25000007e-01
]
], dtype = torch.float32)
}


def warp_tensor(input_tensor : Tensor, alignment_matrix : str) -> Tensor:
matrix = ALIGNMENT_MATRICES.get(alignment_matrix).repeat(input_tensor.shape[0], 1, 1)
matrix = WARP_MATRIX_SET.get(alignment_matrix).repeat(input_tensor.shape[0], 1, 1)
grid = nn.functional.affine_grid(matrix.to(input_tensor.device), list(input_tensor.shape))
output_tensor = nn.functional.grid_sample(input_tensor, grid, align_corners = False, padding_mode = 'reflection')
return output_tensor


def calc_embedding(embedder : EmbedderModule, input_tensor : Tensor, padding : Padding) -> Embedding:
crop_tensor = warp_tensor(input_tensor, '__arcface_128_v2__to__arcface_112_v2__')
crop_tensor = warp_tensor(input_tensor, 'arcface_128_v2_to_arcface_112_v2')
crop_tensor = nn.functional.interpolate(crop_tensor, size = (112, 112), mode = 'area')
crop_tensor[:, :, :padding[0], :] = 0
crop_tensor[:, :, 112 - padding[1]:, :] = 0
Expand Down
2 changes: 1 addition & 1 deletion face_swapper/src/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,4 @@

OptimizerConfig : TypeAlias = Any

AlignmentMatrices : TypeAlias = Dict[str, Tensor]
WarpMatrixSet : TypeAlias = Dict[str, Tensor]

0 comments on commit 4dee56a

Please sign in to comment.