forked from ddPn08/rvc-webui
-
Notifications
You must be signed in to change notification settings - Fork 3
/
merge.py
81 lines (71 loc) · 2.41 KB
/
merge.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from collections import OrderedDict
from typing import *
import torch
import tqdm
def merge(
    path_a: str,
    path_b: str,
    path_c: Optional[str],
    alpha: float,
    weights: Optional[Dict[str, float]],
    method: str,
):
    """Merge two (optionally three) model checkpoints into one.

    Args:
        path_a: Base checkpoint. Its metadata entries (``config``, ``params``,
            ``version``, ``sr``, ``f0``, ``info``, ``embedder_name``,
            ``embedder_output_layer``) are copied into the result.
        path_b: Checkpoint blended into ``path_a``.
        path_c: Checkpoint whose weights are subtracted from ``path_b`` by the
            ``"add_diff"`` method. Required for ``"add_diff"``; not loaded for
            ``"weight_sum"``.
        alpha: Default blend ratio for keys not matched by ``weights``.
        weights: Optional mapping of key prefix -> blend ratio. For each
            parameter key, the longest matching prefix wins. ``None`` or an
            empty mapping means ``alpha`` is used for every key.
        method: ``"weight_sum"`` computes ``(1 - alpha) * a + alpha * b``;
            ``"add_diff"`` computes ``a + (b - c) * alpha``.

    Returns:
        An ``OrderedDict`` holding ``"weight"`` (dict of merged values keyed by
        parameter name) followed by the metadata entries copied from ``path_a``.

    Raises:
        ValueError: If ``method`` is unknown, or ``"add_diff"`` is requested
            without ``path_c``.
        RuntimeError: If ``path_a`` and ``path_b`` do not have identical keys,
            or ``path_c`` lacks keys present in ``path_a``.
        KeyError: If ``path_a`` lacks the ``config``/``sr``/``f0``/``info``
            metadata entries.
    """
    # Validate before loading anything: previously an unknown method silently
    # produced None weights, and add_diff without path_c failed mid-merge.
    if method not in ("weight_sum", "add_diff"):
        raise ValueError(f"Unknown merge method: {method!r}")
    if method == "add_diff" and path_c is None:
        raise ValueError('The "add_diff" merge method requires path_c.')

    def extract(ckpt: Dict[str, Any]):
        """Wrap ``ckpt["model"]`` under a ``"weight"`` key, skipping ``enc_q`` keys."""
        # NOTE(review): enc_q entries are dropped, presumably so the result
        # matches the key set of exported weight files — confirm.
        model = ckpt["model"]
        opt = OrderedDict()
        opt["weight"] = {k: v for k, v in model.items() if "enc_q" not in k}
        return opt

    def load_weight(path: str):
        """Load ``path`` onto the CPU; return ``(flat_weight_dict, raw_state_dict)``."""
        print(f"Loading {path}...")
        # NOTE(security): torch.load unpickles arbitrary objects; only merge
        # checkpoints obtained from trusted sources.
        state_dict = torch.load(path, map_location="cpu")
        if "model" in state_dict:
            # extract() returns {"weight": {...}}; unwrap it so both branches
            # yield the same flat {parameter_name: value} mapping.
            weight = extract(state_dict)["weight"]
        else:
            weight = state_dict["weight"]
        return weight, state_dict

    def get_alpha(key: str) -> float:
        """Return the ratio of the longest ``weights`` prefix of ``key``, else ``alpha``."""
        if not weights:
            return alpha
        matches = [prefix for prefix in weights if key.startswith(prefix)]
        if not matches:
            return alpha
        return weights[max(matches, key=len)]

    def merge_weight(a, b, c, ratio):
        """Combine one parameter from each checkpoint according to ``method``."""
        if method == "weight_sum":
            return (1 - ratio) * a + ratio * b
        # method is guaranteed to be "add_diff" by the validation above.
        return a + (b - c) * ratio

    weight_a, state_dict = load_weight(path_a)
    weight_b, _ = load_weight(path_b)
    # path_c only contributes to add_diff, so skip the (large) load otherwise.
    weight_c = load_weight(path_c)[0] if method == "add_diff" else None

    # dict key views compare as sets, independent of insertion order.
    if weight_a.keys() != weight_b.keys():
        raise RuntimeError("Failed to merge models.")
    if weight_c is not None and not weight_a.keys() <= weight_c.keys():
        raise RuntimeError(
            "Failed to merge models: path_c is missing parameters present in path_a."
        )

    merged = OrderedDict()
    merged["weight"] = {}
    for key in tqdm.tqdm(weight_a.keys()):
        c = weight_c[key] if weight_c is not None else None
        merged["weight"][key] = merge_weight(
            weight_a[key], weight_b[key], c, get_alpha(key)
        )

    # Metadata comes from the base checkpoint; optional entries get defaults.
    merged["config"] = state_dict["config"]
    merged["params"] = state_dict.get("params")
    merged["version"] = state_dict.get("version", "v1")
    merged["sr"] = state_dict["sr"]
    merged["f0"] = state_dict["f0"]
    merged["info"] = state_dict["info"]
    merged["embedder_name"] = state_dict.get("embedder_name")
    merged["embedder_output_layer"] = state_dict.get("embedder_output_layer", "12")
    return merged