diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..3b3b2ac
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,8 @@
+# Python Project
+checkpoints/
+__pycache__/
+.conda/
+.pytest_cache/
+
+*.egg-info/
+```
\ No newline at end of file
diff --git a/frechet_audio_distance/fad.py b/frechet_audio_distance/fad.py
index d07c068..c724bdc 100644
--- a/frechet_audio_distance/fad.py
+++ b/frechet_audio_distance/fad.py
@@ -19,8 +19,10 @@
 from .models.pann import Cnn14, Cnn14_8k, Cnn14_16k
+from encodec import EncodecModel
 
-def load_audio_task(fname, sample_rate, dtype="float32"):
+
+def load_audio_task(fname, sample_rate, channels, dtype="float32"):
     if dtype not in ['float64', 'float32', 'int32', 'int16']:
         raise ValueError(f"dtype not supported: {dtype}")
@@ -32,7 +34,8 @@ def load_audio_task(fname, sample_rate, dtype="float32"):
         wav_data = wav_data / float(2**31)
 
     # Convert to mono
-    if len(wav_data.shape) > 1:
+    assert channels in [1, 2], "channels must be 1 or 2"
+    if len(wav_data.shape) > channels:
         wav_data = np.mean(wav_data, axis=1)
 
     if sr != sample_rate:
@@ -48,6 +51,7 @@ def __init__(
         model_name="vggish",
         submodel_name="630k-audioset",  # only for CLAP
         sample_rate=16000,
+        channels=1,
         use_pca=False,  # only for VGGish
         use_activation=False,  # only for VGGish
         verbose=False,
@@ -57,14 +61,14 @@ def __init__(
         """Initialize FAD
 
         ckpt_dir: folder where the downloaded checkpoints are stored
-        model_name: one between vggish, pann or clap
+        model_name: one between vggish, pann, clap or encodec
         submodel_name: only for clap models - determines which checkpoint to use. options: ["630k-audioset", "630k", "music_audioset", "music_speech", "music_speech_audioset"]
         sample_rate: one between [8000, 16000, 32000, 48000]. depending on the model set the sample rate to use
         use_pca: whether to apply PCA to the vggish embeddings
         use_activation: whether to use the output activation in vggish
         enable_fusion: whether to use fusion for clap models (valid depending on the specific submodel used)
         """
-        assert model_name in ["vggish", "pann", "clap"], "model_name must be either 'vggish', 'pann' or 'clap"
+        assert model_name in ["vggish", "pann", "clap", "encodec"], "model_name must be either 'vggish', 'pann', 'clap', or 'encodec'"
         if model_name == "vggish":
             assert sample_rate == 16000, "sample_rate must be 16000"
         elif model_name == "pann":
@@ -72,9 +76,14 @@ def __init__(
         elif model_name == "clap":
             assert sample_rate == 48000, "sample_rate must be 48000"
             assert submodel_name in ["630k-audioset", "630k", "music_audioset", "music_speech", "music_speech_audioset"]
+        elif model_name == "encodec":
+            assert sample_rate in [24000, 48000], "sample_rate must be 24000 or 48000"
+            if sample_rate == 48000:
+                assert channels == 2, "channels must be 2 for 48khz encodec model"
         self.model_name = model_name
         self.submodel_name = submodel_name
         self.sample_rate = sample_rate
+        self.channels = channels
         self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
         self.verbose = verbose
         self.audio_load_worker = audio_load_worker
@@ -197,11 +206,23 @@ def __get_model(self, model_name="vggish", use_pca=False, use_activation=False):
                 device=self.device)
             self.model.load_ckpt(model_path)
 
+        # encodec
+        elif model_name == "encodec":
+            # choose the right model based on sample_rate
+            # weights are loaded from the encodec repo: https://github.com/facebookresearch/encodec/
+            if self.sample_rate == 24000:
+                self.model = EncodecModel.encodec_model_24khz()
+            elif self.sample_rate == 48000:
+                self.model = EncodecModel.encodec_model_48khz()
+            # 24kbps is the max bandwidth supported by both versions
+            # these models use 32 residual quantizers
+            self.model.set_target_bandwidth(24.0)
+
         self.model.eval()
 
     def get_embeddings(self, x, sr):
         """
-        Get embeddings using VGGish model.
+        Get embeddings using VGGish, PANN, CLAP or EnCodec models.
         Params:
         -- x    : a list of np.ndarray audio samples
         -- sr   : Sampling rate, if x is a list of audio samples. Default value is 16000.
@@ -219,8 +240,39 @@ def get_embeddings(self, x, sr):
             elif self.model_name == "clap":
                 audio = torch.tensor(audio).float().unsqueeze(0)
                 embd = self.model.get_audio_embedding_from_data(audio, use_tensor=True)
-
-            if self.device == torch.device('cuda'):
+            elif self.model_name == "encodec":
+                # add batch and channel dimensions
+                audio = torch.tensor(audio).float().unsqueeze(0).unsqueeze(0)
+                # the 48khz model expects stereo input, so make the audio stereo if needed
+                if self.model.sample_rate == 48000:
+                    if audio.shape[-1] != 2:
+                        print(
+                            "[Frechet Audio Distance] Audio is mono, converting to stereo for 48khz model..."
+                        )
+                        audio = torch.cat((audio, audio), dim=1)
+                    else:
+                        # transpose to (batch, channels, samples)
+                        audio = audio[:, 0].transpose(1, 2)
+
+                if self.verbose:
+                    print(
+                        "[Frechet Audio Distance] Audio shape: {}".format(
+                            audio.shape
+                        )
+                    )
+
+                with torch.no_grad():
+                    # encodec embedding (before quantization)
+                    embd = self.model.encoder(audio)
+                    embd = embd.squeeze(0)
+
+                if self.verbose:
+                    print(
+                        "[Frechet Audio Distance] Embedding shape: {}".format(
+                            embd.shape
+                        )
+                    )
+
+            if self.device == torch.device("cuda"):
                 embd = embd.cpu()
 
             if torch.is_tensor(embd):
@@ -309,8 +361,8 @@ def update(*a):
         for fname in os.listdir(dir):
             res = pool.apply_async(
                 load_audio_task,
-                args=(os.path.join(dir, fname), self.sample_rate, dtype),
-                callback=update
+                args=(os.path.join(dir, fname), self.sample_rate, self.channels, dtype),
+                callback=update,
             )
             task_results.append(res)
         pool.close()
diff --git a/pyproject.toml b/pyproject.toml
index fddf827..2755621 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,6 +28,7 @@ dependencies = [
     'laion_clap',
     'transformers<=4.30.2',
     'torchaudio',
+    'encodec',
 ]
 
 [project.urls]
diff --git a/requirements.txt b/requirements.txt
index baaf3a5..e2b49c6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,4 +9,5 @@
 torchvision
 laion_clap
 transformers<=4.30.2
-torchaudio
\ No newline at end of file
+torchaudio
+encodec
\ No newline at end of file
diff --git a/test/test_all.ipynb b/test/test_all.ipynb
index 4251b6e..5a8f197 100644
--- a/test/test_all.ipynb
+++ b/test/test_all.ipynb
@@ -594,6 +594,104 @@
     "shutil.rmtree(\"test2\")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### EnCodec"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# EnCodec is a neural audio codec: it is trained to compress audio into a latent space and then reconstruct it.\n",
+    "# High-quality reconstructions can be obtained from the generated embeddings.\n",
+    "# It encodes 1 second of audio into 75 embeddings of 128 dimensions each.\n",
+    "SAMPLE_RATE = 24000\n",
+    "LENGTH_IN_SECONDS = 1\n",
+    "\n",
+    "frechet = FrechetAudioDistance(\n",
+    "    ckpt_dir=\"../checkpoints/encodec\",\n",
+    "    model_name=\"encodec\",\n",
+    "    # submodel_name=\"music_speech_audioset\", # for CLAP only\n",
+    "    sample_rate=SAMPLE_RATE,\n",
+    "    # use_pca=True, # for VGGish only\n",
+    "    # use_activation=False, # for VGGish only\n",
+    "    verbose=True,\n",
+    "    audio_load_worker=8,\n",
+    "    # enable_fusion=False, # for CLAP only\n",
+    ")\n",
+    "\n",
+    "for target, count, param in [(\"background\", 10, None), (\"test1\", 5, 0.0001), (\"test2\", 5, 0.00001)]:\n",
+    "    os.makedirs(target, exist_ok=True)\n",
+    "    frequencies = np.linspace(100, 1000, count).tolist()\n",
+    "    for freq in frequencies:\n",
+    "        samples = gen_sine_wave(freq, LENGTH_IN_SECONDS, SAMPLE_RATE, param=param)\n",
+    "        filename = os.path.join(target, \"sin_%.0f.wav\" % freq)\n",
+    "        print(\"Creating: %s with %i samples.\" % (filename, samples.shape[0]))\n",
+    "        sf.write(filename, samples, SAMPLE_RATE, \"PCM_24\")\n",
+    "\n",
+    "fad_score = frechet.score(\"background\", \"test1\")\n",
+    "print(\"FAD score test 1: %.8f\" % fad_score)\n",
+    "\n",
+    "fad_score = frechet.score(\"background\", \"test2\")\n",
+    "print(\"FAD score test 2: %.8f\" % fad_score)\n",
+    "\n",
+    "shutil.rmtree(\"background\")\n",
+    "shutil.rmtree(\"test1\")\n",
+    "shutil.rmtree(\"test2\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# EnCodec's 48kHz version maintains the embedding size, so compression is doubled.\n",
+    "# The 48kHz model expects stereo audio, so the input must have 2 channels.\n",
+    "SAMPLE_RATE = 48000\n",
+    "LENGTH_IN_SECONDS = 1\n",
+    "\n",
+    "frechet = FrechetAudioDistance(\n",
+    "    ckpt_dir=\"../checkpoints/encodec\",\n",
+    "    model_name=\"encodec\",\n",
+    "    # submodel_name=\"music_speech_audioset\", # for CLAP only\n",
+    "    sample_rate=SAMPLE_RATE,\n",
+    "    channels=2,\n",
+    "    # use_pca=True, # for VGGish only\n",
+    "    # use_activation=False, # for VGGish only\n",
+    "    verbose=True,\n",
+    "    audio_load_worker=8,\n",
+    "    # enable_fusion=False, # for CLAP only\n",
+    ")\n",
+    "\n",
+    "for target, count, param in [(\"background\", 10, None), (\"test1\", 5, 0.0001), (\"test2\", 5, 0.00001)]:\n",
+    "    os.makedirs(target, exist_ok=True)\n",
+    "    frequencies = np.linspace(100, 1000, count).tolist()\n",
+    "    for freq in frequencies:\n",
+    "        samples = gen_sine_wave(freq, LENGTH_IN_SECONDS, SAMPLE_RATE, param=param)\n",
+    "        filename = os.path.join(target, \"sin_%.0f.wav\" % freq)\n",
+    "        # make audio stereo\n",
+    "        samples = np.stack([samples, samples], axis=1)\n",
+    "\n",
+    "        print(\"Creating: %s with %i samples.\" % (filename, samples.shape[0]))\n",
+    "        sf.write(filename, samples, SAMPLE_RATE, \"PCM_24\")\n",
+    "\n",
+    "fad_score = frechet.score(\"background\", \"test1\")\n",
+    "print(\"FAD score test 1: %.8f\" % fad_score)\n",
+    "\n",
+    "fad_score = frechet.score(\"background\", \"test2\")\n",
+    "print(\"FAD score test 2: %.8f\" % fad_score)\n",
+    "\n",
+    "shutil.rmtree(\"background\")\n",
+    "shutil.rmtree(\"test1\")\n",
+    "shutil.rmtree(\"test2\")"
+   ]
+  },
  {
   "cell_type": "code",
   "execution_count": null,
@@ -618,7 +716,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.12"
+   "version": "3.10.13"
   },
  "orig_nbformat": 4
 },
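
Usage note (editor's sketch, not part of the diff): the snippet below shows how the EnCodec backend added above would be selected through the updated FrechetAudioDistance constructor. The checkpoint path and the "background"/"test1" folder names are placeholders; model_name, sample_rate, and channels mirror the assertions introduced in __init__.

from frechet_audio_distance import FrechetAudioDistance

# 24 kHz EnCodec backend: mono input, embeddings taken from the encoder output
frechet = FrechetAudioDistance(
    ckpt_dir="../checkpoints/encodec",  # placeholder checkpoint directory
    model_name="encodec",
    sample_rate=24000,
    channels=1,
    verbose=True,
)

# The 48 kHz EnCodec model requires stereo input, per the new assertion:
# FrechetAudioDistance(model_name="encodec", sample_rate=48000, channels=2, ...)

# "background" and "test1" are placeholder directories of .wav files
fad_score = frechet.score("background", "test1")
print("FAD score: %.8f" % fad_score)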