-
-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 1670eff
Showing
22 changed files
with
459 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
root = true | ||
|
||
[*] | ||
end_of_line = lf | ||
insert_final_newline = true | ||
indent_size = 4 | ||
indent_style = tab | ||
trim_trailing_whitespace = true |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
[flake8] | ||
select = E3, E4, F, I1, I2 | ||
plugins = flake8-import-order | ||
application_import_names = arcface_converter | ||
import-order-style = pycharm | ||
per-file-ignores = preparing.py:E402 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
github: henryruhs | ||
custom: [ buymeacoffee.com/henryruhs, paypal.me/henryruhs ] |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
name: ci | ||
|
||
on: [ push, pull_request ] | ||
|
||
jobs: | ||
lint: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- name: Checkout | ||
uses: actions/checkout@v4 | ||
- name: Set up Python 3.10 | ||
uses: actions/setup-python@v5 | ||
with: | ||
python-version: '3.10' | ||
- run: pip install flake8 | ||
- run: pip install flake8-import-order | ||
- run: pip install mypy | ||
- run: flake8 arcface_converter | ||
- run: mypy arcface_converter |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
.idea | ||
.vscode |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
MIT license | ||
|
||
Copyright (c) 2024 Henry Ruhs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
FaceFusion Labs | ||
=============== | ||
|
||
> Industry leading face manipulation platform. | ||
[](https://github.com/facefusion/facefusion-labs/actions?query=workflow:ci) | ||
 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
ArcFace Converter | ||
================= | ||
|
||
> Convert face embeddings between various ArcFace models. | ||
|
||
Preview | ||
------- | ||
|
||
 | ||
|
||
|
||
Installation | ||
------------ | ||
|
||
``` | ||
pip install -r requirements.txt | ||
``` | ||
|
||
|
||
Example | ||
------- | ||
|
||
This example utilizes the MegaFace dataset to train an ArcFace Converter for SimSwap. | ||
|
||
``` | ||
[preparing.dataset] | ||
dataset_path = datasets/megaface/train.rec | ||
crop_size = 112 | ||
process_limit = 650000 | ||
[preparing.model] | ||
source_path = models/arcface_w600k_r50.onnx | ||
target_path = models/arcface_simswap.onnx | ||
[preparing.input] | ||
directory_path = inputs | ||
source_path = inputs/arcface_w600k_r50.npy | ||
target_path = inputs/arcface_simswap.npy | ||
[training.loader] | ||
split_ratio = 0.8 | ||
batch_size = 51200 | ||
num_workers = 8 | ||
[training.trainer] | ||
max_epochs = 4096 | ||
[training.output] | ||
directory_path = outputs | ||
file_pattern = arcface_converter_simswap_{epoch:02d}_{val_loss:.4f} | ||
[exporting] | ||
directory_path = exports | ||
source_path = outputs/last.ckpt | ||
target_path = exports/arcface_converter_simswap.onnx | ||
opset_version = 15 | ||
[execution] | ||
providers = CUDAExecutionProvider | ||
``` | ||
|
||
|
||
Preparing | ||
--------- | ||
|
||
Prepare the face embedding pairs. | ||
|
||
``` | ||
python prepare.py | ||
``` | ||
|
||
|
||
Training | ||
-------- | ||
|
||
Train the ArcFace converter model. | ||
|
||
``` | ||
python train.py | ||
``` | ||
|
||
|
||
Exporting | ||
--------- | ||
|
||
Export the model to ONNX. | ||
|
||
``` | ||
python export.py | ||
``` |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
[preparing.dataset] | ||
dataset_path = | ||
crop_size = | ||
process_limit = | ||
|
||
[preparing.model] | ||
source_path = | ||
target_path = | ||
|
||
[preparing.input] | ||
directory_path = | ||
source_path = | ||
target_path = | ||
|
||
[training.loader] | ||
split_ratio = | ||
batch_size = | ||
num_workers = | ||
|
||
[training.trainer] | ||
max_epochs = | ||
|
||
[training.output] | ||
directory_path = | ||
file_pattern = | ||
|
||
[exporting] | ||
directory_path = | ||
source_path = | ||
target_path = | ||
opset_version = | ||
|
||
[execution] | ||
providers = |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
#!/usr/bin/env python3 | ||
|
||
from src.exporting import export | ||
|
||
if __name__ == '__main__': | ||
export() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
#!/usr/bin/env python3 | ||
|
||
from src.preparing import prepare | ||
|
||
if __name__ == '__main__': | ||
prepare() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
import configparser | ||
from os import makedirs | ||
|
||
import torch | ||
|
||
from .training import ArcFaceConverterTrainer | ||
|
||
CONFIG = configparser.ConfigParser() | ||
CONFIG.read('config.ini') | ||
|
||
|
||
def export() -> None: | ||
directory_path = CONFIG.get('exporting', 'directory_path') | ||
source_path = CONFIG.get('exporting', 'source_path') | ||
target_path = CONFIG.get('exporting', 'target_path') | ||
opset_version = CONFIG.getint('exporting', 'opset_version') | ||
|
||
makedirs(directory_path, exist_ok = True) | ||
model = ArcFaceConverterTrainer.load_from_checkpoint(source_path, map_location = 'cpu') | ||
model.eval() | ||
input_tensor = torch.randn(1, 512) | ||
torch.onnx.export(model, input_tensor, target_path, input_names = [ 'input' ], output_names = [ 'output' ], opset_version = opset_version) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import torch | ||
import torch.nn as nn | ||
from torch import Tensor | ||
|
||
|
||
class ArcFaceConverter(nn.Module): | ||
def __init__(self) -> None: | ||
super(ArcFaceConverter, self).__init__() | ||
self.fc1 = nn.Linear(512, 1024) | ||
self.fc2 = nn.Linear(1024, 2048) | ||
self.fc3 = nn.Linear(2048, 1024) | ||
self.fc4 = nn.Linear(1024, 512) | ||
self.activation = nn.LeakyReLU() | ||
|
||
def forward(self, inputs : Tensor) -> Tensor: | ||
norm_inputs = inputs / torch.norm(inputs) | ||
outputs = self.activation(self.fc1(norm_inputs)) | ||
outputs = self.activation(self.fc2(outputs)) | ||
outputs = self.activation(self.fc3(outputs)) | ||
outputs = self.fc4(outputs) | ||
return outputs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import configparser | ||
from os import makedirs | ||
from os.path import isfile | ||
from typing import List | ||
|
||
import numpy | ||
numpy.bool = numpy.bool_ | ||
from mxnet.io import ImageRecordIter | ||
from onnxruntime import InferenceSession | ||
from tqdm import tqdm | ||
|
||
from .typing import Embedding, EmbeddingPairs, VisionFrame | ||
|
||
CONFIG = configparser.ConfigParser() | ||
CONFIG.read('config.ini') | ||
|
||
|
||
def prepare_crop_vision_frame(crop_vision_frame : VisionFrame) -> VisionFrame: | ||
crop_vision_frame = crop_vision_frame.astype(numpy.float32) / 255 | ||
crop_vision_frame = (crop_vision_frame - 0.5) * 2 | ||
return crop_vision_frame | ||
|
||
|
||
def create_inference_session(model_path : str, execution_providers : List[str]) -> InferenceSession: | ||
inference_session = InferenceSession(model_path, providers = execution_providers) | ||
return inference_session | ||
|
||
|
||
def forward(inference_session : InferenceSession, crop_vision_frame : VisionFrame) -> Embedding: | ||
embedding = inference_session.run(None, | ||
{ | ||
'input': crop_vision_frame | ||
})[0] | ||
|
||
return embedding | ||
|
||
|
||
def process_embeddings(dataset_reader : ImageRecordIter, source_inference_session : InferenceSession, target_inference_session : InferenceSession) -> EmbeddingPairs: | ||
dataset_process_limit = CONFIG.getint('preparing.dataset', 'process_limit') | ||
embedding_pairs = [] | ||
|
||
with tqdm(total = dataset_process_limit) as progress: | ||
for batch in dataset_reader: | ||
crop_vision_frame = batch.data[0].asnumpy() | ||
crop_vision_frame = prepare_crop_vision_frame(crop_vision_frame) | ||
source_embedding = forward(source_inference_session, crop_vision_frame) | ||
target_embedding = forward(target_inference_session, crop_vision_frame) | ||
embedding_pairs.append([ source_embedding, target_embedding ]) | ||
progress.update() | ||
|
||
if progress.n == dataset_process_limit: | ||
return numpy.concatenate(embedding_pairs, axis = 1).T | ||
|
||
return numpy.concatenate(embedding_pairs, axis = 1).T | ||
|
||
|
||
def prepare() -> None: | ||
dataset_path = CONFIG.get('preparing.dataset', 'dataset_path') | ||
dataset_crop_size = CONFIG.getint('preparing.dataset', 'crop_size') | ||
model_source_path = CONFIG.get('preparing.model', 'source_path') | ||
model_target_path = CONFIG.get('preparing.model', 'target_path') | ||
input_directory_path = CONFIG.get('preparing.input', 'directory_path') | ||
input_source_path = CONFIG.get('preparing.input', 'source_path') | ||
input_target_path = CONFIG.get('preparing.input', 'target_path') | ||
execution_providers = CONFIG.get('execution', 'providers').split(' ') | ||
|
||
makedirs(input_directory_path, exist_ok = True) | ||
if isfile(dataset_path) and isfile(model_source_path) and isfile(model_target_path): | ||
dataset_reader = ImageRecordIter( | ||
path_imgrec = dataset_path, | ||
data_shape = (3, dataset_crop_size, dataset_crop_size), | ||
batch_size = 1, | ||
shuffle = False | ||
) | ||
source_inference_session = create_inference_session(model_source_path, execution_providers) | ||
target_inference_session = create_inference_session(model_target_path, execution_providers) | ||
embedding_pairs = process_embeddings(dataset_reader, source_inference_session, target_inference_session) | ||
numpy.save(input_source_path, embedding_pairs[..., 0].T) | ||
numpy.save(input_target_path, embedding_pairs[..., 1].T) |
Oops, something went wrong.