added wav2vec2 embeddings
FBurkhardt committed Dec 3, 2021
1 parent 7791ea7 commit 7a826cc
Showing 7 changed files with 97 additions and 3 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -66,7 +66,7 @@ Here's [an overview on the ini-file options](./ini_file.md)

### Features
* Classifiers: XGB, XGR, SVM, SVR, MLP
* Feature extractors: opensmile, openXBOW BoAW, TRILL embeddings
* Feature extractors: opensmile, openXBOW BoAW, TRILL embeddings, Wav2vec2 embeddings, ...
* Feature scaling
* Label encoding
* Binning (continuous to categorical)
4 changes: 3 additions & 1 deletion ini_file.md
@@ -69,7 +69,9 @@
* set = eGeMAPSv02 *(feature set)*
* level = functionals *(or lld: feature level)*
* **spectra**: Melspecs for convolutional networks
* **trill**: [TRILL embeddings](https://ai.googleblog.com/2020/06/improving-speech-representations-and.html)
* **trill**: [TRILL embeddings](https://ai.googleblog.com/2020/06/improving-speech-representations-and.html) from Google
* **wav2vec**: [Wav2vec2 embeddings](https://huggingface.co/facebook/wav2vec2-large-robust-ft-swbd-300h) from Facebook
* model = *path to the wav2vec2 model folder*
* **mld**: [mid-level-descriptors](http://www.essv.de/paper.php?id=447)
* mld = *path to the mld sources folder*
* min_syls = *minimum number of syllables*
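For orientation, a minimal `[FEATS]` section selecting the new extractor might look like the following sketch (the `type` key is assumed from `feats_type` in src/experiment.py below; the model path is a placeholder):

```ini
[FEATS]
type = wav2vec
model = /path/to/wav2vec2-large-robust-ft-swbd-300h
```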
1 change: 1 addition & 0 deletions requirements_mld.txt
@@ -0,0 +1 @@
scikit_posthocs
6 changes: 6 additions & 0 deletions requirements_wav2vec.txt
@@ -0,0 +1,6 @@
audiofile
git+https://github.com/huggingface/datasets.git
git+https://github.com/huggingface/transformers.git
jiwer
torchaudio
librosa
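These wav2vec2-specific dependencies are kept separate from the base requirements and can be installed on demand, e.g. with `pip install -r requirements_wav2vec.txt`.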
8 changes: 8 additions & 0 deletions src/experiment.py
@@ -157,6 +157,14 @@ def extract_feats(self):
            self.feats_test = TRILLset(f'{feats_name}_test', df_test)
            self.feats_test.extract()
            self.feats_test.filter()
        elif feats_type == 'wav2vec':
            from feats_wav2vec2 import Wav2vec2
            self.feats_train = Wav2vec2(f'{feats_name}_train', df_train)
            self.feats_train.extract()
            self.feats_train.filter()
            self.feats_test = Wav2vec2(f'{feats_name}_test', df_test)
            self.feats_test.extract()
            self.feats_test.filter()
        elif feats_type == 'mld':
            from feats_mld import MLD_set
            self.feats_train = MLD_set(f'{feats_name}_train', df_train)
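For orientation, here is a stand-alone sketch of what the new branch does for the wav2vec case (the files-as-index convention is taken from the code above; the surrounding experiment configuration via `glob_conf` and `util` is assumed to be initialized):

```python
import pandas as pd
from feats_wav2vec2 import Wav2vec2

# the data frame is indexed by audio file paths (placeholders here)
df_train = pd.DataFrame(index=['wavs/a.wav', 'wavs/b.wav'])

feats_train = Wav2vec2('wav2vec_train', df_train)
feats_train.extract()  # computes embeddings, or reloads them from the pickle store
feats_train.filter()   # inherited from Featureset
print(feats_train.df.shape)  # one row per file, one column per embedding dimension
```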
3 changes: 2 additions & 1 deletion src/feats_trill.py
@@ -35,7 +35,8 @@ def extract(self):
            for idx, file in enumerate(self.data_df.index):
                emb = self.getEmbeddings(file)
                emb_series[idx] = emb
                self.util.debug(f'TRILL: {idx} of {length} done')
                if idx % 10 == 0:
                    self.util.debug(f'TRILL: {idx} of {length} done')
            self.df = pd.DataFrame(emb_series.values.tolist(), index=self.data_df.index)
            self.df.to_pickle(storage)
            try:
76 changes: 76 additions & 0 deletions src/feats_wav2vec2.py
@@ -0,0 +1,76 @@
# feats_wav2vec2.py

from util import Util
from featureset import Featureset
import os
import pandas as pd
import glob_conf
import numpy as np
import transformers
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Model
import torch

import audiofile

class Wav2vec2(Featureset):
    """Class to extract wav2vec2 embeddings (https://huggingface.co/facebook/wav2vec2-large-robust-ft-swbd-300h)"""

    def __init__(self, name, data_df):
        """Constructor, taking the feature store name and the data index."""
        super().__init__(name, data_df)
        model_path = self.util.config_val('FEATS', 'model', 'wav2vec2-large-robust-ft-swbd-300h')
        self.device = self.util.config_val('MODEL', 'device', 'cpu')

        # load processor and model, and switch to inference mode
        self.processor = transformers.Wav2Vec2Processor.from_pretrained(model_path)
        self.model = Wav2Vec2Model.from_pretrained(model_path).to(self.device)
        self.model.eval()

    def extract(self):
        """Extract the features or load them from disk if present."""
        store = self.util.get_path('store')
        storage = f'{store}{self.name}.pkl'
        extract = self.util.config_val('FEATS', 'needs_feature_extraction', False)
        if extract or not os.path.isfile(storage):
            self.util.debug('extracting wav2vec2 embeddings, this might take a while...')
            emb_series = pd.Series(index=self.data_df.index, dtype=object)
            length = len(self.data_df.index)
            for idx, file in enumerate(self.data_df.index):
                emb = self.get_embeddings(file)
                emb_series[idx] = emb
                if idx % 10 == 0:
                    self.util.debug(f'Wav2vec2: {idx} of {length} done')
            self.df = pd.DataFrame(emb_series.values.tolist(), index=self.data_df.index)
            self.df.to_pickle(storage)
            try:
                glob_conf.config['DATA']['needs_feature_extraction'] = 'false'
            except KeyError:
                pass
        else:
            self.util.debug('reusing extracted wav2vec2 embeddings')
            self.df = pd.read_pickle(storage)

    def get_embeddings(self, audio_path):
        r"""Extract embeddings from a raw audio signal."""
        signal, sampling_rate = audiofile.read(audio_path, always_2d=True)
        with torch.no_grad():
            # run through processor to normalize signal
            # always returns a batch, so we just get the first entry
            # then we put it on the device
            y = self.processor(signal, sampling_rate=sampling_rate)
            y = y['input_values'][0]
            y = torch.from_numpy(y).to(self.device)

            # run through model
            # first entry contains hidden state
            y = self.model(y)[0]

            # pool result over time and convert to numpy
            y = torch.mean(y, dim=1)
            y = y.detach().cpu().numpy()

        return y.flatten()
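Because `get_embeddings` mean-pools the hidden states over time, every file is reduced to a single fixed-length vector regardless of its duration. A quick shape check, as a sketch (the path is a placeholder and a local checkpoint is assumed):

```python
import pandas as pd
from feats_wav2vec2 import Wav2vec2

probe = Wav2vec2('probe', pd.DataFrame(index=['wavs/a.wav']))
emb = probe.get_embeddings('wavs/a.wav')
print(emb.shape)  # e.g. (1024,): the hidden size of the wav2vec2-large checkpoints
```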
