Build custom Tokenizer and custom Processor flows for wav2vec2 models. (#7)

* Add vocab Dataclass.
* Add custom processor function to trainer.
* Test that we properly generate output from the vocab list.
* Add both training and test sets to vocabulary for tokenizer.
* Update transformers version to fix empty transcription results.
harrykeightley authored Sep 19, 2023
1 parent cff43a0 commit db11fe1
Showing 11 changed files with 411 additions and 134 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,4 +1,5 @@
*.pyc
testdir

# Packages
*.egg
16 changes: 11 additions & 5 deletions elpis/datasets/processing.py
@@ -3,6 +3,7 @@
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
from datasets import Audio, DatasetDict, load_dataset
from loguru import logger
from transformers import Wav2Vec2Processor
@@ -63,7 +64,11 @@ def prepare_dataset(dataset: DatasetDict, processor: Wav2Vec2Processor) -> DatasetDict:
    """

    logger.debug(f"Dataset pre prep: {dataset}")
    logger.debug(f"Dataset[train] pre prep: {dataset['train']['transcript']}")
    logger.debug(f"Dataset[train] pre prep: {dataset['train']['transcript'][0]}")
    logger.debug(
        f'Input array shape:, {np.asarray(dataset["train"][0]["audio"]["array"]).shape}'
    )
    logger.debug(f'Sampling rate:, {dataset["train"][0]["audio"]["sampling_rate"]}')
    logger.debug(f"Tokenizer vocab: {processor.tokenizer.vocab}") # type: ignore

    def _prepare_dataset(batch: Dict) -> Dict[str, List]:
@@ -79,9 +84,9 @@ def _prepare_dataset(batch: Dict) -> Dict[str, List]:

        return batch

    column_names = [dataset.column_names[key] for key in dataset.column_names.keys()]
    # flatten
    columns_to_remove = list(chain.from_iterable(column_names))
    columns = dataset.column_names.values()
    # flatten and make unique between datasets
    columns_to_remove = list(set(chain.from_iterable(columns)))

    dataset = dataset.map(
        _prepare_dataset,
@@ -90,5 +95,6 @@ def _prepare_dataset(batch: Dict) -> Dict[str, List]:
    )

    logger.debug(f"Dataset post prep: {dataset}")
    logger.debug(f"Training labels: {dataset['train']['labels']}")
    logger.debug(f"Training labels: {dataset['train']['labels'][0]}")
    # logger.debug(f"Training inputs: {dataset['train']['input_values'][0]}")
    return dataset
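The body of _prepare_dataset is mostly collapsed in this diff. For reference, here is a hedged sketch of the standard wav2vec2 per-example mapping such a function performs; the helper name is hypothetical and the column names are taken from the debug logging above, so the repository's exact code may differ.

```python
# Sketch only: assumes each example carries "audio" (with "array" and
# "sampling_rate") and "transcript" columns, matching the logs in prepare_dataset.
from transformers import Wav2Vec2Processor


def _prepare_example(example: dict, processor: Wav2Vec2Processor) -> dict:
    audio = example["audio"]
    # The feature extractor turns the raw waveform into model inputs.
    example["input_values"] = processor(
        audio["array"], sampling_rate=audio["sampling_rate"]
    ).input_values[0]
    # The tokenizer turns the transcript into CTC label ids.
    example["labels"] = processor.tokenizer(example["transcript"]).input_ids
    return example
```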
3 changes: 2 additions & 1 deletion elpis/models/__init__.py
@@ -1,4 +1,5 @@
from elpis.models.annotation import Annotation
from elpis.models.elan_options import ElanOptions, ElanTierSelector
from elpis.models.vocab import VOCAB_FILE, Vocab

__all__ = ["Annotation", "ElanOptions", "ElanTierSelector"]
__all__ = ["Annotation", "ElanOptions", "ElanTierSelector", "Vocab", "VOCAB_FILE"]
65 changes: 65 additions & 0 deletions elpis/models/vocab.py
@@ -0,0 +1,65 @@
import json
from dataclasses import dataclass
from functools import reduce
from pathlib import Path
from typing import Any, Dict, Iterable, Optional, Set

VOCAB_FILE = "vocab.json"


@dataclass
class Vocab:
    """A class which represents a dictionary of encountered tokens in a dataset."""

    vocab: Dict[str, int]

    @property
    def symbols(self) -> Set[str]:
        return set(self.vocab.keys())

    def merge(self, other: "Vocab") -> "Vocab":
        """Creates a new Vocab containing all symbols from both vocabs."""
        vocab = self.symbols | other.symbols
        return Vocab.from_set(vocab)

    def save(self, path: Path) -> None:
        """Saves the vocab to the supplied path.
        If the path is a folder, saves as vocab.json within it.
        """
        if path.is_dir():
            path /= VOCAB_FILE

        with open(path, "w") as out:
            json.dump(self.vocab, out)

    def add(self, char: str) -> None:
        """Adds a new character into the vocab."""
        if char in self.vocab:
            return

        self.vocab[char] = len(self.vocab)

    def replace(self, original: str, replacement: str) -> None:
        """Replaces the supplied character mapping in the vocab."""
        if original not in self.vocab:
            return

        self.vocab[replacement] = self.vocab[original]
        self.vocab.pop(original)

    @classmethod
    def from_set(cls, symbols: Set[str]) -> "Vocab":
        """Builds a vocab from a set of symbols."""
        vocab = {symbol: index for index, symbol in enumerate(sorted(symbols))}
        return cls(vocab=vocab)

    @classmethod
    def from_strings(cls, texts: Iterable[str]) -> "Vocab":
        """Builds a vocab from an iterable text collection."""

        def reducer(result: Set[str], text: str) -> Set[str]:
            return result | set(text)

        symbols = reduce(reducer, texts, set())
        return cls.from_set(symbols)
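A minimal usage sketch of the Vocab dataclass above; the transcripts, special tokens, and output path are toy values chosen for illustration, not values from the repository.

```python
from pathlib import Path

from elpis.models import Vocab

train_vocab = Vocab.from_strings(["hello world", "good morning"])  # toy transcripts
test_vocab = Vocab.from_strings(["goodbye world"])
vocab = train_vocab.merge(test_vocab)  # union of both symbol sets, re-indexed

vocab.add("[UNK]")  # special tokens take the next free indices
vocab.add("[PAD]")
vocab.replace(" ", "|")  # remap the space symbol to the word delimiter

vocab.save(Path("."))  # path is a directory, so this writes ./vocab.json (VOCAB_FILE)
```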
2 changes: 1 addition & 1 deletion elpis/trainer/data_collator.py
@@ -17,7 +17,7 @@ class DataCollatorCTCWithPadding:
    def __call__(
        self, features: List[Dict[str, Union[List[int], torch.Tensor]]]
    ) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lenghts and need
        # split inputs and labels since they have to be of different lengths and need
        # different padding methods
        input_features = [
            {"input_values": feature["input_values"]} for feature in features
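The rest of __call__ is outside this diff. For context, here is a standalone sketch of the common wav2vec2 CTC collation pattern the comment describes; this is the general Hugging Face recipe, not the class body from this repository.

```python
from transformers import Wav2Vec2Processor


def collate_ctc(features: list, processor: Wav2Vec2Processor) -> dict:
    # Audio inputs and label sequences are padded separately because they have
    # different lengths and need different padding values.
    input_features = [{"input_values": f["input_values"]} for f in features]
    label_features = [{"input_ids": f["labels"]} for f in features]

    batch = processor.pad(input_features, padding=True, return_tensors="pt")
    labels_batch = processor.tokenizer.pad(
        label_features, padding=True, return_tensors="pt"
    )

    # Padded label positions are replaced with -100 so the CTC loss ignores them.
    labels = labels_batch["input_ids"].masked_fill(
        labels_batch["attention_mask"].ne(1), -100
    )
    batch["labels"] = labels
    return batch
```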
14 changes: 7 additions & 7 deletions elpis/trainer/job.py
@@ -29,9 +29,9 @@ class TrainingOptions:
    learning_rate: float = 1e-4
    min_duration: int = 0
    max_duration: int = 60
    word_delimiter_token: str = " "
    test_size: float = 0.2
    freeze_feature_extractor: bool = False
    # word_delimiter_token: str = " " # Note: This might interfere with the tokenizer?
    test_size: float = 0.2 # TODO: link with dataset?
    freeze_feature_extractor: bool = True

    @staticmethod
    def from_dict(data: Dict[str, Any]) -> "TrainingOptions":
@@ -68,10 +68,10 @@ def to_training_args(self, output_dir: Path, **kwargs) -> TrainingArguments:
            gradient_checkpointing=True,
            learning_rate=self.options.learning_rate,
            weight_decay=0.005,
            save_steps=500,
            eval_steps=500,
            logging_steps=500,
            warmup_steps=1000,
            save_steps=400,
            eval_steps=400,
            logging_steps=400,
            warmup_steps=500,
            save_total_limit=2,
            overwrite_output_dir=True,
            do_train=True,
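The freeze_feature_extractor default flips to True here. How the flag is consumed is outside this diff; the usual Hugging Face wav2vec2 recipe (an assumption about this codebase, shown only for context) looks like:

```python
from transformers import AutoModelForCTC

# Stand-in checkpoint; the real code would use job.base_model.
model = AutoModelForCTC.from_pretrained("facebook/wav2vec2-base-960h")

freeze_feature_extractor = True
if freeze_feature_extractor:
    # Keep the pretrained convolutional feature encoder frozen so only the
    # transformer layers and the CTC head are updated during fine-tuning.
    model.freeze_feature_encoder()
```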
73 changes: 67 additions & 6 deletions elpis/trainer/trainer.py
@@ -1,18 +1,70 @@
from contextlib import nullcontext
from pathlib import Path
from typing import Dict, Optional
from typing import Optional

from datasets import DatasetDict
from loguru import logger
from tokenizers import Tokenizer
from transformers import AutoModelForCTC, AutoProcessor, EvalPrediction, Trainer
from transformers import (
    AutoConfig,
    AutoFeatureExtractor,
    AutoModelForCTC,
    AutoProcessor,
    AutoTokenizer,
    Trainer,
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2Processor,
)

from elpis.datasets import create_dataset, prepare_dataset
from elpis.models.vocab import VOCAB_FILE, Vocab
from elpis.trainer.data_collator import DataCollatorCTCWithPadding
from elpis.trainer.job import TrainingJob
from elpis.trainer.metrics import create_metrics
from elpis.trainer.utils import log_to_file


def create_processor(
    job: TrainingJob,
    output_dir: Path,
    dataset: DatasetDict,
    cache_dir: Optional[Path],
    unk_token="[UNK]",
    pad_token="[PAD]",
    word_delimiter_token="|",
) -> Wav2Vec2Processor:
    config = AutoConfig.from_pretrained(job.base_model)
    tokenizer_type = config.model_type if config.tokenizer_class is None else None
    config = config if config.tokenizer_class is not None else None

    # Build up a vocab from the dataset.
    train_vocab = Vocab.from_strings(dataset["train"]["transcript"])
    test_vocab = Vocab.from_strings(dataset["test"]["transcript"])
    vocab = train_vocab.merge(test_vocab)

    vocab.add(unk_token)
    vocab.add(pad_token)
    vocab.replace(" ", word_delimiter_token) # feels a little restrictive?
    logger.info(f"Vocab: {vocab.vocab}")
    vocab.save(output_dir)

    tokenizer = AutoTokenizer.from_pretrained(
        output_dir,
        config=config,
        tokenizer_type=tokenizer_type,
        unk_token=unk_token,
        pad_token=pad_token,
        word_delimiter_token=word_delimiter_token,
        cache_dir=cache_dir,
    )

    feature_extractor = AutoFeatureExtractor.from_pretrained(
        job.base_model, cache_dir=cache_dir
    )

    return Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)


def train(
    job: TrainingJob,
    output_dir: Path,
@@ -37,7 +89,7 @@ def train(
    with context:
        logger.info("Preparing Datasets...")
        dataset = create_dataset(dataset_dir, cache_dir)
        processor = AutoProcessor.from_pretrained(job.base_model, cache_dir=cache_dir)
        processor = create_processor(job, output_dir, dataset, cache_dir)
        dataset = prepare_dataset(dataset, processor)
        logger.info("Finished Preparing Datasets")

@@ -46,7 +98,16 @@ def train(
            job.base_model,
            cache_dir=cache_dir,
            ctc_loss_reduction="mean",
            pad_token_id=processor.tokenizer.pad_token_id,
            pad_token_id=processor.tokenizer.pad_token_id, # type: ignore
            # Wav2vec2 specific hyperparams copied from docs.
            attention_dropout=0.1,
            hidden_dropout=0.1,
            feat_proj_dropout=0.0,
            mask_time_prob=0.05,
            layerdrop=0.1,
            vocab_size=len(processor.tokenizer), # type: ignore
            # For Ash -> errors if below param not set.
            ignore_mismatched_sizes=True,
        )
        logger.info("Downloaded model.")

@@ -61,7 +122,7 @@ def train(
            args=job.to_training_args(output_dir),
            train_dataset=dataset["train"], # type: ignore
            eval_dataset=dataset["test"], # type: ignore
            tokenizer=processor.feature_extractor,
            tokenizer=processor.feature_extractor, # type: ignore
            data_collator=data_collator,
            compute_metrics=create_metrics(job.metrics, processor),
        )
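Putting the new pieces together, a self-contained sketch of the tokenizer flow that create_processor implements: build a character vocab from transcripts, write vocab.json, then load a CTC tokenizer from that directory. The transcripts, output path, and base checkpoint are illustrative stand-ins, and Wav2Vec2CTCTokenizer is used directly here as a simplification of the AutoTokenizer/AutoFeatureExtractor calls shown above.

```python
from pathlib import Path

from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor

from elpis.models import Vocab

transcripts = ["hello world", "good morning"]  # stand-in for dataset["train"]["transcript"]
vocab = Vocab.from_strings(transcripts)
vocab.add("[UNK]")
vocab.add("[PAD]")
vocab.replace(" ", "|")

output_dir = Path("output")
output_dir.mkdir(exist_ok=True)
vocab.save(output_dir)  # writes output/vocab.json

tokenizer = Wav2Vec2CTCTokenizer(
    str(output_dir / "vocab.json"),
    unk_token="[UNK]",
    pad_token="[PAD]",
    word_delimiter_token="|",
)
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
    "facebook/wav2vec2-base-960h"  # stand-in for job.base_model
)
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
```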