-
Notifications
You must be signed in to change notification settings - Fork 107
/
loss.py
254 lines (215 loc) · 8.03 KB
/
loss.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
# Copyright (c) 2024 NVIDIA CORPORATION.
# Licensed under the MIT license.
# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
# LICENSE is in incl_licenses directory.
import torch
import torch.nn.functional as F
import torch.nn as nn
from librosa.filters import mel as librosa_mel_fn
from scipy import signal
import typing
from typing import Optional, List, Union, Dict, Tuple
from collections import namedtuple
import math
import functools
# Adapted from https://github.com/descriptinc/descript-audio-codec/blob/main/dac/nn/loss.py under the MIT license.
# LICENSE is in incl_licenses directory.
class MultiScaleMelSpectrogramLoss(nn.Module):
    """Compute distance between mel spectrograms. Can be used
    in a multi-scale way.

    Parameters
    ----------
    sampling_rate : int
        Sampling rate handed to the mel filterbank construction.
    n_mels : List[int], optional
        Number of mels per STFT, by default [5, 10, 20, 40, 80, 160, 320]
    window_lengths : List[int], optional
        Length of each window of each STFT, by default [32, 64, 128, 256, 512, 1024, 2048].
        The hop length of every STFT is ``window_length // 4``.
    loss_fn : typing.Callable, optional
        How to compare each loss, by default nn.L1Loss()
    clamp_eps : float, optional
        Clamp on the log magnitude, below, by default 1e-5
    mag_weight : float, optional
        Weight of raw magnitude portion of loss, by default 0.0 (no amplification on mag part)
    log_weight : float, optional
        Weight of log magnitude portion of loss, by default 1.0
    pow : float, optional
        Power to raise magnitude to before taking log, by default 1.0
    weight : float, optional
        Weight of this loss, by default 1.0. It is only stored on the instance;
        ``forward`` does not multiply by it.
    match_stride : bool, optional
        Whether to match the stride of convolutional layers, by default False
    mel_fmin : List[float], optional
        Lowest mel filter frequency per scale, by default [0, 0, 0, 0, 0, 0, 0]
    mel_fmax : List[float], optional
        Highest mel filter frequency per scale (``None`` is passed through to the
        filterbank builder), by default [None, None, None, None, None, None, None]
    window_type : str, optional
        Window name passed to ``scipy.signal.get_window``, by default "hann"

    Notes
    -----
    ``n_mels``, ``mel_fmin``, ``mel_fmax`` and ``window_lengths`` are iterated in
    lockstep with ``zip``, so only as many scales as the shortest list are used.

    Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py
    Additional code copied and modified from https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py
    """

    def __init__(
        self,
        sampling_rate: int,
        n_mels: Optional[List[int]] = None,
        window_lengths: Optional[List[int]] = None,
        loss_fn: typing.Callable = nn.L1Loss(),
        clamp_eps: float = 1e-5,
        mag_weight: float = 0.0,
        log_weight: float = 1.0,
        pow: float = 1.0,
        weight: float = 1.0,
        match_stride: bool = False,
        mel_fmin: Optional[List[float]] = None,
        mel_fmax: Optional[List[Optional[float]]] = None,
        window_type: str = "hann",
    ):
        super().__init__()
        # ``None`` stands in for the documented defaults so that no list object
        # is shared between instances (mutable-default-argument pitfall).
        if n_mels is None:
            n_mels = [5, 10, 20, 40, 80, 160, 320]
        if window_lengths is None:
            window_lengths = [32, 64, 128, 256, 512, 1024, 2048]
        if mel_fmin is None:
            mel_fmin = [0, 0, 0, 0, 0, 0, 0]
        if mel_fmax is None:
            mel_fmax = [None, None, None, None, None, None, None]

        self.sampling_rate = sampling_rate

        STFTParams = namedtuple(
            "STFTParams",
            ["window_length", "hop_length", "window_type", "match_stride"],
        )
        self.stft_params = [
            STFTParams(
                window_length=w,
                hop_length=w // 4,
                match_stride=match_stride,
                window_type=window_type,
            )
            for w in window_lengths
        ]
        # Copy so later mutation of the caller's lists cannot change this loss.
        self.n_mels = list(n_mels)
        self.loss_fn = loss_fn
        self.clamp_eps = clamp_eps
        self.log_weight = log_weight
        self.mag_weight = mag_weight
        self.weight = weight
        self.mel_fmin = list(mel_fmin)
        self.mel_fmax = list(mel_fmax)
        self.pow = pow

    @staticmethod
    @functools.lru_cache(None)
    def get_window(
        window_type,
        window_length,
    ):
        """Return the scipy window for ``(window_type, window_length)``, memoized."""
        return signal.get_window(window_type, window_length)

    @staticmethod
    @functools.lru_cache(None)
    def get_mel_filters(sr, n_fft, n_mels, fmin, fmax):
        """Return the librosa mel filterbank for the given parameters, memoized."""
        return librosa_mel_fn(sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)

    def mel_spectrogram(
        self,
        wav,
        n_mels,
        fmin,
        fmax,
        window_length,
        hop_length,
        match_stride,
        window_type,
    ):
        """Compute a magnitude mel spectrogram of ``wav``.

        Mirrors AudioSignal.mel_spectrogram used by BigVGAN-v2 training from:
        https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py

        Parameters
        ----------
        wav : torch.Tensor
            Waveform with three dimensions, unpacked as (B, C, T).
        n_mels, fmin, fmax
            Mel filterbank parameters forwarded to ``get_mel_filters``.
        window_length, hop_length : int
            STFT window size (also used as ``n_fft``) and hop.
        match_stride : bool
            If True, reflect-pad the signal and trim two frames from each end
            of the STFT.
        window_type : str
            Window name forwarded to ``get_window``.

        Returns
        -------
        torch.Tensor
            Mel magnitudes with the mel axis at dim 2 and frames last.
        """
        B, C, T = wav.shape

        if match_stride:
            assert (
                hop_length == window_length // 4
            ), "For match_stride, hop must equal n_fft // 4"
            right_pad = math.ceil(T / hop_length) * hop_length - T
            pad = (window_length - hop_length) // 2
        else:
            right_pad = 0
            pad = 0

        wav = torch.nn.functional.pad(wav, (pad, pad + right_pad), mode="reflect")

        window = self.get_window(window_type, window_length)
        window = torch.from_numpy(window).to(wav.device).float()

        stft = torch.stft(
            # Flatten (B, C) into one batch axis. The time length must be read
            # from the *padded* tensor: reusing the pre-padding T breaks (or
            # silently mis-shapes) the reshape whenever match_stride pads.
            wav.reshape(-1, wav.shape[-1]),
            n_fft=window_length,
            hop_length=hop_length,
            window=window,
            return_complex=True,
            center=True,
        )
        _, nf, nt = stft.shape
        stft = stft.reshape(B, C, nf, nt)

        if match_stride:
            # Drop first two and last two frames, which are added because of
            # padding. Now num_frames * hop_length = num_samples.
            stft = stft[..., 2:-2]

        magnitude = torch.abs(stft)

        nf = magnitude.shape[2]
        # nf frequency bins correspond to an FFT size of 2 * (nf - 1).
        mel_basis = self.get_mel_filters(
            self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax
        )
        mel_basis = torch.from_numpy(mel_basis).to(wav.device)

        # Move frequency last, project onto the mel filters, then move the
        # resulting mel axis back to dim 2.
        mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T
        mel_spectrogram = mel_spectrogram.transpose(-1, 2)
        return mel_spectrogram

    def _log_mels(self, mels: torch.Tensor) -> torch.Tensor:
        """Return log10 of ``mels`` after clamping below and raising to ``self.pow``."""
        return torch.log10(mels.clamp(min=self.clamp_eps).pow(self.pow))

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Computes mel loss between an estimate and a reference
        signal.

        Parameters
        ----------
        x : torch.Tensor
            Estimate signal
        y : torch.Tensor
            Reference signal

        Returns
        -------
        torch.Tensor
            Mel loss, summed over all scales. Each scale contributes
            ``log_weight * loss_fn(log-mels)`` plus
            ``mag_weight * loss_fn(raw mels)``.
        """
        loss = 0.0
        for n_mels, fmin, fmax, s in zip(
            self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
        ):
            kwargs = {
                "n_mels": n_mels,
                "fmin": fmin,
                "fmax": fmax,
                "window_length": s.window_length,
                "hop_length": s.hop_length,
                "match_stride": s.match_stride,
                "window_type": s.window_type,
            }

            x_mels = self.mel_spectrogram(x, **kwargs)
            y_mels = self.mel_spectrogram(y, **kwargs)

            loss += self.log_weight * self.loss_fn(
                self._log_mels(x_mels), self._log_mels(y_mels)
            )
            # The magnitude term compares the raw (non-log) mel magnitudes;
            # comparing the log-mels here would just rescale the log term.
            loss += self.mag_weight * self.loss_fn(x_mels, y_mels)

        return loss
# Loss functions
def feature_loss(
    fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]
) -> torch.Tensor:
    """Feature-matching loss between real and generated feature maps.

    Pairs the outer lists, then pairs the tensors inside each pair, and
    accumulates the mean absolute difference of every tensor pair. The
    accumulated value is doubled before being returned. With no pairs at
    all the result is the plain integer 0.
    """
    per_layer_l1 = (
        torch.mean(torch.abs(real_feat - gen_feat))
        for real_maps, gen_maps in zip(fmap_r, fmap_g)
        for real_feat, gen_feat in zip(real_maps, gen_maps)
    )
    # The factor 2 equates to lambda=2.0 for the feature matching loss.
    return sum(per_layer_l1) * 2
def discriminator_loss(
    disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
) -> Tuple[torch.Tensor, List[float], List[float]]:
    """Least-squares discriminator loss over paired discriminator outputs.

    For each (real, generated) output pair the real term is
    ``mean((1 - real) ** 2)`` and the generated term is
    ``mean(generated ** 2)``; both are added to the running total.

    Returns
    -------
    loss
        Sum of all real and generated terms (the integer 0 if the inputs
        are empty).
    r_losses, g_losses
        Per-pair real and generated terms as Python floats (taken with
        ``.item()``), so they carry no autograd graph.
    """
    loss = 0
    r_losses = []
    g_losses = []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        r_loss = torch.mean((1 - dr) ** 2)
        g_loss = torch.mean(dg**2)
        loss += r_loss + g_loss
        # .item() detaches to plain floats; these lists are for reporting only.
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())

    return loss, r_losses, g_losses
def generator_loss(
    disc_outputs: List[torch.Tensor],
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
    """Least-squares generator loss over discriminator outputs.

    Each discriminator output contributes ``mean((1 - output) ** 2)``.
    Returns the sum of those terms (the integer 0 when there are no
    outputs) together with the list of individual terms, kept as tensors.
    """
    gen_losses = [torch.mean((1 - dg) ** 2) for dg in disc_outputs]
    return sum(gen_losses), gen_losses