[feature] add EnCodec model embeddings to FAD calculation #23

Merged
merged 5 commits into from Nov 20, 2023
Changes from 4 commits
8 changes: 8 additions & 0 deletions .gitignore
@@ -0,0 +1,8 @@
# Python Project
checkpoints/
__pycache__/
.conda/
.pytest_cache/

*.egg-info/
67 changes: 59 additions & 8 deletions frechet_audio_distance/fad.py
@@ -19,6 +19,8 @@

from .models.pann import Cnn14, Cnn14_8k, Cnn14_16k

from encodec import EncodecModel


def load_audio_task(fname, sample_rate, dtype="float32"):
if dtype not in ['float64', 'float32', 'int32', 'int16']:
@@ -32,7 +34,8 @@ def load_audio_task(fname, sample_rate, dtype="float32"):
wav_data = wav_data / float(2**31)

# Convert to mono
if len(wav_data.shape) > 1:
assert channels in [1, 2], "channels must be 1 or 2"
Owner

channels here is unspecified; I guess you're missing a line before this to calculate channels from wav_data.

Contributor Author

Yes, it's fixed now. It's an input to the function, defaulting to 1 unless specified (as should be done for the EnCodec 48 kHz model).

if len(wav_data.shape) > channels:
wav_data = np.mean(wav_data, axis=1)

if sr != sample_rate:
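A minimal sketch of the mono/stereo handling discussed in the thread above, assuming only what this hunk shows (the full `load_audio_task` signature with the new `channels` argument lies outside the visible lines):

```python
import numpy as np

def downmix_sketch(wav_data: np.ndarray, channels: int = 1) -> np.ndarray:
    """Illustrative only: mirrors the check added in this hunk."""
    assert channels in (1, 2), "channels must be 1 or 2"
    # A stereo array has ndim == 2: with channels == 1 it is averaged down to
    # mono; with channels == 2 (the 48 kHz EnCodec case) it is left untouched.
    if len(wav_data.shape) > channels:
        wav_data = np.mean(wav_data, axis=1)
    return wav_data
```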
@@ -48,6 +51,7 @@ def __init__(
model_name="vggish",
submodel_name="630k-audioset", # only for CLAP
sample_rate=16000,
channels=1,
use_pca=False, # only for VGGish
use_activation=False, # only for VGGish
verbose=False,
@@ -57,21 +61,25 @@
"""Initialize FAD

ckpt_dir: folder where the downloaded checkpoints are stored
model_name: one between vggish, pann or clap
model_name: one between vggish, pann, clap or encodec
submodel_name: only for clap models - determines which checkpoint to use. options: ["630k-audioset", "630k", "music_audioset", "music_speech", "music_speech_audioset"]
sample_rate: one of [8000, 16000, 24000, 32000, 48000], depending on the model used
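channels: number of audio channels, defaults to 1 (mono); must be 2 for the 48 kHz EnCodec model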
use_pca: whether to apply PCA to the vggish embeddings
use_activation: whether to use the output activation in vggish
enable_fusion: whether to use fusion for clap models (valid depending on the specific submodel used)
"""
assert model_name in ["vggish", "pann", "clap"], "model_name must be either 'vggish', 'pann' or 'clap"
assert model_name in ["vggish", "pann", "clap", "encodec"], "model_name must be either 'vggish', 'pann', 'clap', or 'encodec'"
if model_name == "vggish":
assert sample_rate == 16000, "sample_rate must be 16000"
elif model_name == "pann":
assert sample_rate in [8000, 16000, 32000], "sample_rate must be 8000, 16000 or 32000"
elif model_name == "clap":
assert sample_rate == 48000, "sample_rate must be 48000"
assert submodel_name in ["630k-audioset", "630k", "music_audioset", "music_speech", "music_speech_audioset"]
elif model_name == "encodec":
assert sample_rate in [24000, 48000], "sample_rate must be 24000 or 48000"
if sample_rate == 48000:
assert channels == 2, "channels must be 2 for 48khz encodec model"
self.model_name = model_name
self.submodel_name = submodel_name
self.sample_rate = sample_rate
@@ -197,11 +205,23 @@ def __get_model(self, model_name="vggish", use_pca=False, use_activation=False):
device=self.device)
self.model.load_ckpt(model_path)

# encodec
elif model_name == "encodec":
# choose the right model based on sample_rate
# weights are loaded from the encodec repo: https://github.com/facebookresearch/encodec/
if self.sample_rate == 24000:
self.model = EncodecModel.encodec_model_24khz()
elif self.sample_rate == 48000:
self.model = EncodecModel.encodec_model_48khz()
# 24kbps is the max bandwidth supported by both versions
# these models use 32 residual quantizers
self.model.set_target_bandwidth(24.0)

self.model.eval()

def get_embeddings(self, x, sr):
"""
Get embeddings using VGGish model.
Get embeddings using VGGish, PANN, CLAP or EnCodec models.
Params:
-- x : a list of np.ndarray audio samples
-- sr : Sampling rate, if x is a list of audio samples. Default value is 16000.
@@ -219,8 +239,39 @@ def get_embeddings(self, x, sr):
elif self.model_name == "clap":
audio = torch.tensor(audio).float().unsqueeze(0)
embd = self.model.get_audio_embedding_from_data(audio, use_tensor=True)

if self.device == torch.device('cuda'):
elif self.model_name == "encodec":
# add batch and channel dimensions
audio = torch.tensor(audio).float().unsqueeze(0).unsqueeze(0)
# the 48 kHz model expects stereo input, so duplicate mono audio across channels
if self.model.sample_rate == 48000:
if audio.shape[-1] != 2:
print(
"[Frechet Audio Distance] Audio is mono, converting to stereo for 48khz model..."
)
audio = torch.cat((audio, audio), dim=1)
else:
# transpose to (batch, channels, samples)
audio = audio[:, 0].transpose(1, 2)

if self.verbose:
print(
"[Frechet Audio Distance] Audio shape: {}".format(
audio.shape
)
)

with torch.no_grad():
# encodec embedding (before quantization)
embd = self.model.encoder(audio)
embd = embd.squeeze(0)

if self.verbose:
print(
"[Frechet Audio Distance] Embedding shape: {}".format(
embd.shape
)
)
if self.device == torch.device("cuda"):
embd = embd.cpu()

if torch.is_tensor(embd):
@@ -309,8 +360,8 @@ def update(*a):
for fname in os.listdir(dir):
res = pool.apply_async(
load_audio_task,
args=(os.path.join(dir, fname), self.sample_rate, dtype),
callback=update
args=(os.path.join(dir, fname), self.sample_rate, self.channels, dtype),
Owner

I suspect self.channels is not defined here, but please correct me if I am wrong.

Contributor Author

Yep, it looks like I erased the definition when committing.
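For readers skimming the thread, a hedged sketch of the fix being referred to (the actual assignment sits outside the lines shown in this diff): the constructor argument gets stored on the instance so the `apply_async` call above can pass `self.channels` to `load_audio_task`.

```python
# Hypothetical sketch; the real assignment lives in __init__ and is not
# visible in this hunk.
class _InitSketch:
    def __init__(self, sample_rate=16000, channels=1):
        self.sample_rate = sample_rate
        self.channels = channels  # 1 by default, 2 for the 48 kHz EnCodec model
```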

callback=update,
)
task_results.append(res)
pool.close()
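To make the new code path easier to follow end to end, here is a short standalone sketch (not part of the PR) that mirrors what `__get_model` and `get_embeddings` now do for EnCodec: pick the checkpoint by sample rate, set the 24 kbps target bandwidth, and take the encoder output before quantization as the embedding. The synthetic sine input is an assumption for illustration only.

```python
import torch
from encodec import EncodecModel

SAMPLE_RATE = 24000  # 24 kHz mono checkpoint; 48000 selects the stereo checkpoint

# Mirror __get_model: choose the checkpoint by sample rate and set the
# 24 kbps bandwidth supported by both versions.
model = (EncodecModel.encodec_model_24khz() if SAMPLE_RATE == 24000
         else EncodecModel.encodec_model_48khz())
model.set_target_bandwidth(24.0)
model.eval()

# One second of a 440 Hz sine as a stand-in for a loaded audio file.
t = torch.arange(SAMPLE_RATE, dtype=torch.float32) / SAMPLE_RATE
audio = torch.sin(2 * torch.pi * 440.0 * t)

# Mirror get_embeddings: shape to (batch, channels, samples); the 48 kHz
# model expects stereo, so mono input would be duplicated across channels.
audio = audio.unsqueeze(0).unsqueeze(0)
if model.sample_rate == 48000 and audio.shape[1] == 1:
    audio = torch.cat((audio, audio), dim=1)

with torch.no_grad():
    embd = model.encoder(audio)  # embedding before quantization
print(embd.squeeze(0).shape)     # approximately (128, 75) for 1 s at 24 kHz
```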
1 change: 1 addition & 0 deletions pyproject.toml
@@ -28,6 +28,7 @@ dependencies = [
'laion_clap',
'transformers<=4.30.2',
'torchaudio',
'encodec',
]

[project.urls]
3 changes: 2 additions & 1 deletion requirements.txt
@@ -9,4 +9,5 @@ torchvision

laion_clap
transformers<=4.30.2
torchaudio
torchaudio
encodec
100 changes: 99 additions & 1 deletion test/test_all.ipynb
@@ -594,6 +594,104 @@
"shutil.rmtree(\"test2\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### EnCodec"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# EnCodec is a neural audio codec: it is trained to compress audio into a latent space and then reconstruct it.\n",
"# High-quality reconstruction can be obtained from the generated embeddings.\n",
"# It encodes 1 second of audio into 75 embeddings of 128 dimensions each (see the sanity-check sketch at the end of this diff).\n",
"SAMPLE_RATE = 24000\n",
"LENGTH_IN_SECONDS = 1\n",
"\n",
"frechet = FrechetAudioDistance(\n",
" ckpt_dir=\"../checkpoints/encodec\",\n",
" model_name=\"encodec\",\n",
" # submodel_name=\"music_speech_audioset\", # for CLAP only\n",
" sample_rate=SAMPLE_RATE,\n",
" # use_pca=True, # for VGGish only\n",
" # use_activation=False, # for VGGish only\n",
" verbose=True,\n",
" audio_load_worker=8,\n",
" # enable_fusion=False, # for CLAP only\n",
")\n",
"\n",
"for traget, count, param in [(\"background\", 10, None), (\"test1\", 5, 0.0001), (\"test2\", 5, 0.00001)]:\n",
" os.makedirs(traget, exist_ok=True)\n",
" frequencies = np.linspace(100, 1000, count).tolist()\n",
" for freq in frequencies:\n",
" samples = gen_sine_wave(freq, LENGTH_IN_SECONDS, SAMPLE_RATE, param=param)\n",
" filename = os.path.join(traget, \"sin_%.0f.wav\" % freq)\n",
" print(\"Creating: %s with %i samples.\" % (filename, samples.shape[0]))\n",
" sf.write(filename, samples, SAMPLE_RATE, \"PCM_24\")\n",
"\n",
"fad_score = frechet.score(\"background\", \"test1\")\n",
"print(\"FAD score test 1: %.8f\" % fad_score)\n",
"\n",
"fad_score = frechet.score(\"background\", \"test2\")\n",
"print(\"FAD score test 2: %.8f\" % fad_score)\n",
"\n",
"shutil.rmtree(\"background\")\n",
"shutil.rmtree(\"test1\")\n",
"shutil.rmtree(\"test2\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# EnCodec's 48kHz version maintains the embedding size, so compression is doubled.\n",
"# The model available for 48kHz audio expects stereo audio, so the input audio must have 2 channels.\n",
"SAMPLE_RATE = 48000\n",
"LENGTH_IN_SECONDS = 1\n",
"\n",
"frechet = FrechetAudioDistance(\n",
" ckpt_dir=\"../checkpoints/encodec\",\n",
" model_name=\"encodec\",\n",
" # submodel_name=\"music_speech_audioset\", # for CLAP only\n",
" sample_rate=SAMPLE_RATE,\n",
" channels=2,\n",
" # use_pca=True, # for VGGish only\n",
" # use_activation=False, # for VGGish only\n",
" verbose=True,\n",
" audio_load_worker=8,\n",
" # enable_fusion=False, # for CLAP only\n",
")\n",
"\n",
"for traget, count, param in [(\"background\", 10, None), (\"test1\", 5, 0.0001), (\"test2\", 5, 0.00001)]:\n",
" os.makedirs(traget, exist_ok=True)\n",
" frequencies = np.linspace(100, 1000, count).tolist()\n",
" for freq in frequencies:\n",
" samples = gen_sine_wave(freq, LENGTH_IN_SECONDS, SAMPLE_RATE, param=param)\n",
" filename = os.path.join(traget, \"sin_%.0f.wav\" % freq)\n",
" # make audio stereo\n",
" samples = np.stack([samples, samples], axis=1)\n",
"\n",
" print(\"Creating: %s with %i samples.\" % (filename, samples.shape[0]))\n",
" sf.write(filename, samples, SAMPLE_RATE, \"PCM_24\")\n",
"\n",
"fad_score = frechet.score(\"background\", \"test1\")\n",
"print(\"FAD score test 1: %.8f\" % fad_score)\n",
"\n",
"fad_score = frechet.score(\"background\", \"test2\")\n",
"print(\"FAD score test 2: %.8f\" % fad_score)\n",
"\n",
"shutil.rmtree(\"background\")\n",
"shutil.rmtree(\"test1\")\n",
"shutil.rmtree(\"test2\")"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -618,7 +716,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.12"
"version": "3.10.13"
},
"orig_nbformat": 4
},
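As a quick sanity check on the figures quoted in the EnCodec notebook cells above (1 second of 24 kHz audio mapped to 75 embeddings of 128 dimensions each), the implied hop size and embedding rate work out as follows; this is plain arithmetic, not something asserted elsewhere in the PR.

```python
# Sanity check of the numbers in the 24 kHz EnCodec notebook cell above.
SAMPLE_RATE = 24000        # input samples per second
FRAMES_PER_SECOND = 75     # embedding frames per second (quoted in the cell)
EMBEDDING_DIM = 128        # dimensions per embedding frame (quoted in the cell)

hop = SAMPLE_RATE // FRAMES_PER_SECOND
print(hop)                                # 320 samples per embedding frame
print(FRAMES_PER_SECOND * EMBEDDING_DIM)  # 9600 embedding values per second
```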