-
Notifications
You must be signed in to change notification settings - Fork 24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[feature] add EnCodec model embeddings to FAD calculation #23
Merged
Merged
Changes from 4 commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
9a135c1
[chore] added gitignore for cleaner commiting
ivanlmh 4b7d3bd
[feature] added encodec embeddings to FAD calculation
ivanlmh 5039f58
[feature] fixed-up encodec FAD and added 24khz test to notebook
ivanlmh e7a0672
[feature] added encodec as requirement and 48khz model test to notebook
ivanlmh 67f6f69
[fix] added missing channels variable definition
ivanlmh File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# Python Project | ||
checkpoints/ | ||
__pycache__/ | ||
.conda/ | ||
.pytest_cache/ | ||
|
||
*.egg-info/ | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,8 @@ | |
|
||
from .models.pann import Cnn14, Cnn14_8k, Cnn14_16k | ||
|
||
from encodec import EncodecModel | ||
|
||
|
||
def load_audio_task(fname, sample_rate, dtype="float32"): | ||
if dtype not in ['float64', 'float32', 'int32', 'int16']: | ||
|
@@ -32,7 +34,8 @@ def load_audio_task(fname, sample_rate, dtype="float32"): | |
wav_data = wav_data / float(2**31) | ||
|
||
# Convert to mono | ||
if len(wav_data.shape) > 1: | ||
assert channels in [1, 2], "channels must be 1 or 2" | ||
if len(wav_data.shape) > channels: | ||
wav_data = np.mean(wav_data, axis=1) | ||
|
||
if sr != sample_rate: | ||
|
@@ -48,6 +51,7 @@ def __init__( | |
model_name="vggish", | ||
submodel_name="630k-audioset", # only for CLAP | ||
sample_rate=16000, | ||
channels=1, | ||
use_pca=False, # only for VGGish | ||
use_activation=False, # only for VGGish | ||
verbose=False, | ||
|
@@ -57,21 +61,25 @@ def __init__( | |
"""Initialize FAD | ||
|
||
ckpt_dir: folder where the downloaded checkpoints are stored | ||
model_name: one between vggish, pann or clap | ||
model_name: one between vggish, pann, clap or encodec | ||
submodel_name: only for clap models - determines which checkpoint to use. options: ["630k-audioset", "630k", "music_audioset", "music_speech", "music_speech_audioset"] | ||
sample_rate: one between [8000, 16000, 32000, 48000]. depending on the model set the sample rate to use | ||
use_pca: whether to apply PCA to the vggish embeddings | ||
use_activation: whether to use the output activation in vggish | ||
enable_fusion: whether to use fusion for clap models (valid depending on the specific submodel used) | ||
""" | ||
assert model_name in ["vggish", "pann", "clap"], "model_name must be either 'vggish', 'pann' or 'clap" | ||
assert model_name in ["vggish", "pann", "clap", "encodec"], "model_name must be either 'vggish', 'pann', 'clap', or 'encodec'" | ||
if model_name == "vggish": | ||
assert sample_rate == 16000, "sample_rate must be 16000" | ||
elif model_name == "pann": | ||
assert sample_rate in [8000, 16000, 32000], "sample_rate must be 8000, 16000 or 32000" | ||
elif model_name == "clap": | ||
assert sample_rate == 48000, "sample_rate must be 48000" | ||
assert submodel_name in ["630k-audioset", "630k", "music_audioset", "music_speech", "music_speech_audioset"] | ||
elif model_name == "encodec": | ||
assert sample_rate in [24000, 48000], "sample_rate must be 24000 or 48000" | ||
if sample_rate == 48000: | ||
assert channels == 2, "channels must be 2 for 48khz encodec model" | ||
self.model_name = model_name | ||
self.submodel_name = submodel_name | ||
self.sample_rate = sample_rate | ||
|
@@ -197,11 +205,23 @@ def __get_model(self, model_name="vggish", use_pca=False, use_activation=False): | |
device=self.device) | ||
self.model.load_ckpt(model_path) | ||
|
||
# encodec | ||
elif model_name == "encodec": | ||
# choose the right model based on sample_rate | ||
# weights are loaded from the encodec repo: https://github.com/facebookresearch/encodec/ | ||
if self.sample_rate == 24000: | ||
self.model = EncodecModel.encodec_model_24khz() | ||
elif self.sample_rate == 48000: | ||
self.model = EncodecModel.encodec_model_48khz() | ||
# 24kbps is the max bandwidth supported by both versions | ||
# these models use 32 residual quantizers | ||
self.model.set_target_bandwidth(24.0) | ||
|
||
self.model.eval() | ||
|
||
def get_embeddings(self, x, sr): | ||
""" | ||
Get embeddings using VGGish model. | ||
Get embeddings using VGGish, PANN, CLAP or EnCodec models. | ||
Params: | ||
-- x : a list of np.ndarray audio samples | ||
-- sr : Sampling rate, if x is a list of audio samples. Default value is 16000. | ||
|
@@ -219,8 +239,39 @@ def get_embeddings(self, x, sr): | |
elif self.model_name == "clap": | ||
audio = torch.tensor(audio).float().unsqueeze(0) | ||
embd = self.model.get_audio_embedding_from_data(audio, use_tensor=True) | ||
|
||
if self.device == torch.device('cuda'): | ||
elif self.model_name == "encodec": | ||
# add two dimensions | ||
audio = torch.tensor(audio).float().unsqueeze(0).unsqueeze(0) | ||
# if SAMPLE_RATE is 48000, we need to make audio stereo | ||
if self.model.sample_rate == 48000: | ||
if audio.shape[-1] != 2: | ||
print( | ||
"[Frechet Audio Distance] Audio is mono, converting to stereo for 48khz model..." | ||
) | ||
audio = torch.cat((audio, audio), dim=1) | ||
else: | ||
# transpose to (batch, channels, samples) | ||
audio = audio[:, 0].transpose(1, 2) | ||
|
||
if self.verbose: | ||
print( | ||
"[Frechet Audio Distance] Audio shape: {}".format( | ||
audio.shape | ||
) | ||
) | ||
|
||
with torch.no_grad(): | ||
# encodec embedding (before quantization) | ||
embd = self.model.encoder(audio) | ||
embd = embd.squeeze(0) | ||
|
||
if self.verbose: | ||
print( | ||
"[Frechet Audio Distance] Embedding shape: {}".format( | ||
embd.shape | ||
) | ||
) | ||
if self.device == torch.device("cuda"): | ||
embd = embd.cpu() | ||
|
||
if torch.is_tensor(embd): | ||
|
@@ -309,8 +360,8 @@ def update(*a): | |
for fname in os.listdir(dir): | ||
res = pool.apply_async( | ||
load_audio_task, | ||
args=(os.path.join(dir, fname), self.sample_rate, dtype), | ||
callback=update | ||
args=(os.path.join(dir, fname), self.sample_rate, self.channels, dtype), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I suspect There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yep, it looks like I erased the definition when commiting |
||
callback=update, | ||
) | ||
task_results.append(res) | ||
pool.close() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,6 +28,7 @@ dependencies = [ | |
'laion_clap', | ||
'transformers<=4.30.2', | ||
'torchaudio', | ||
'encodec', | ||
] | ||
|
||
[project.urls] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,4 +9,5 @@ torchvision | |
|
||
laion_clap | ||
transformers<=4.30.2 | ||
torchaudio | ||
torchaudio | ||
encodec |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
channels
here is unspecified, I guess you miss a line before this to calculatechannels
fromwav_data
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, it's fixed now. It's an input to the function, defaulting to 1 unless specified (as should be done for encodec 48khz)