-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
FBurkhardt
committed
Dec 3, 2021
1 parent
7791ea7
commit 7a826cc
Showing
7 changed files
with
97 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
scikit_posthocs |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
audiofile | ||
git+https://github.com/huggingface/datasets.git | ||
git+https://github.com/huggingface/transformers.git | ||
jiwer | ||
torchaudio | ||
librosa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# feats_wav2vec2.py | ||
|
||
from util import Util | ||
from featureset import Featureset | ||
import os | ||
import pandas as pd | ||
import os | ||
import glob_conf | ||
import numpy as np | ||
import transformers | ||
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2Model | ||
import torch | ||
|
||
import audiofile | ||
|
||
class Wav2vec2(Featureset):
    """Featureset that extracts wav2vec2 embeddings.

    Default model:
    https://huggingface.co/facebook/wav2vec2-large-robust-ft-swbd-300h
    """

    def __init__(self, name, data_df):
        """Load the wav2vec2 processor and model named in the configuration.

        Args:
            name: Name of the feature set; also used as the stem of the
                pickle file written by :meth:`extract`.
            data_df: Dataframe whose index entries are handed to
                ``audiofile.read`` one by one (presumably audio file paths;
                verify against caller).
        """
        super().__init__(name, data_df)
        # self.util is not set here; presumably provided by Featureset.__init__.
        model_path = self.util.config_val('FEATS', 'model', 'wav2vec2-large-robust-ft-swbd-300h')
        self.device = self.util.config_val('MODEL', 'device', 'cpu')

        # Load the pretrained processor and model; eval() switches the model
        # to inference mode.
        self.processor = transformers.Wav2Vec2Processor.from_pretrained(model_path)
        self.model = Wav2Vec2Model.from_pretrained(model_path).to(self.device)
        self.model.eval()

    def extract(self):
        """Extract the features or load them from disk if present.

        The result is stored in ``self.df`` (one row per entry of
        ``self.data_df.index``) and pickled to ``<store><name>.pkl`` so that
        later runs can reuse it.
        """
        store = self.util.get_path('store')
        storage = f'{store}{self.name}.pkl'
        # NOTE(review): if config_val returns the raw config string, a value
        # like 'False' would be truthy here -- confirm against Util.config_val.
        extract = self.util.config_val('FEATS', 'needs_feature_extraction', False)
        if extract or not os.path.isfile(storage):
            self.util.debug('extracting wav2vec2 embeddings, this might take a while...')
            length = len(self.data_df.index)
            # Collect the embeddings in a plain list instead of assigning into
            # a label-indexed Series with an integer key: that relied on
            # pandas' deprecated positional fallback for Series[int].
            embeddings = []
            for idx, file in enumerate(self.data_df.index):
                embeddings.append(self.get_embeddings(file))
                if idx % 10 == 0:
                    self.util.debug(f'Wav2vec2: {idx} of {length} done')
            self.df = pd.DataFrame(embeddings, index=self.data_df.index)
            self.df.to_pickle(storage)
            # NOTE(review): the flag is read from section 'FEATS' above but
            # reset in section 'DATA' here -- confirm which one is intended.
            try:
                glob_conf.config['DATA']['needs_feature_extraction'] = 'false'
            except KeyError:
                pass
        else:
            self.util.debug('reusing extracted wav2vec2 embeddings')
            self.df = pd.read_pickle(storage)

    def get_embeddings(self, audio_path):
        r"""Extract embeddings from raw audio signal.

        Args:
            audio_path: Path of the audio file, read with ``audiofile.read``.

        Returns:
            Flattened numpy array holding the model's hidden states averaged
            over dimension 1.
        """
        signal, sampling_rate = audiofile.read(audio_path, always_2d=True)
        # NOTE(review): with always_2d=True a multi-channel file presumably
        # yields one row per channel, which would make the flattened result
        # longer than for mono files -- confirm the inputs are mono.
        with torch.no_grad():
            # run through processor to normalize signal
            # always returns a batch, so we just get the first entry
            # then we put it on the device
            y = self.processor(signal, sampling_rate=sampling_rate)
            y = y['input_values'][0]
            y = torch.from_numpy(y).to(self.device)

            # run through model
            # first entry contains hidden state
            y = self.model(y)[0]

            # pool result and convert to numpy
            y = torch.mean(y, dim=1)
            y = y.detach().cpu().numpy()

        return y.flatten()