Skip to content

Commit

Permalink
Fix export using Trainer
Browse files Browse the repository at this point in the history
  • Loading branch information
henryruhs committed Feb 26, 2025
1 parent c487f48 commit e855e4e
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 5 deletions.
2 changes: 1 addition & 1 deletion embedding_converter/src/exporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,5 @@ def export() -> None:
model = EmbeddingConverterTrainer.load_from_checkpoint(source_path, map_location = 'cpu')
model.eval()
model.ir_version = torch.tensor(ir_version)
input_tensor = (torch.randn(1, 512), )
input_tensor = torch.randn(1, 512)
torch.onnx.export(model, input_tensor, target_path, input_names = [ 'input' ], output_names = [ 'output' ], opset_version = opset_version)
6 changes: 2 additions & 4 deletions face_swapper/src/exporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch

from .models.generator import Generator
from .training import FaceSwapperTrainer

CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
Expand All @@ -17,9 +17,7 @@ def export() -> None:
opset_version = CONFIG.getint('exporting', 'opset_version')

makedirs(directory_path, exist_ok = True)
state_dict = torch.load(source_path, map_location = 'cpu').get('state_dict').get('generator')
model = Generator()
model.load_state_dict(state_dict)
model = FaceSwapperTrainer.load_from_checkpoint(source_path, map_location = 'cpu')
model.eval()
model.ir_version = torch.tensor(ir_version)
source_tensor = torch.randn(1, 512)
Expand Down

0 comments on commit e855e4e

Please sign in to comment.