Skip to content

Commit

Permalink
Add Metrics (#6)
Browse files Browse the repository at this point in the history
* Update dependencies in poetry and add accelerate package

* Add metrics in properly.
---------

Co-authored-by: Ben Foley <[email protected]>
  • Loading branch information
harrykeightley and benfoley authored Sep 8, 2023
1 parent 3fb405b commit cff43a0
Show file tree
Hide file tree
Showing 6 changed files with 2,016 additions and 302 deletions.
26 changes: 20 additions & 6 deletions elpis/datasets/processing.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
from itertools import chain
from pathlib import Path
from typing import Any, Dict, List, Optional

from datasets import Audio, DatasetDict, load_dataset
from loguru import logger
from transformers import Wav2Vec2Processor

PROCESSOR_COUNT = 4
Expand Down Expand Up @@ -60,21 +62,33 @@ def prepare_dataset(dataset: DatasetDict, processor: Wav2Vec2Processor) -> Datas
processor: The processor to apply over the dataset
"""

def prepare_dataset(batch: Dict) -> Dict[str, List]:
logger.debug(f"Dataset pre prep: {dataset}")
logger.debug(f"Dataset[train] pre prep: {dataset['train']['transcript']}")
logger.debug(f"Tokenizer vocab: {processor.tokenizer.vocab}") # type: ignore

    def _prepare_dataset(batch: Dict) -> Dict[str, List]:
        """Map one dataset example to model inputs and label ids, in place.

        Adds "input_values", "input_length" and "labels" keys to *batch*
        using the `processor` captured from the enclosing scope.
        """
        # Also from https://huggingface.co/blog/fine-tune-xlsr-wav2vec2
        # NOTE(review): assumes batch["audio"] is a datasets.Audio dict with
        # "array" and "sampling_rate" keys — confirm against the loader.
        audio = batch["audio"]

        # Feature-extract the raw waveform into the model's input space;
        # the processor returns a batch, so take element [0].
        batch["input_values"] = processor(
            audio["array"], sampling_rate=audio["sampling_rate"]
        ).input_values[0]
        batch["input_length"] = len(batch["input_values"])

        # NOTE(review): both of the next two assignments set batch["labels"];
        # the second (keyword `text=`) overwrites the first, so only it takes
        # effect. The first looks like leftover code — confirm and remove.
        with processor.as_target_processor():
            batch["labels"] = processor(batch["transcript"]).input_ids
        batch["labels"] = processor(text=batch["transcript"]).input_ids

        return batch

return dataset.map(
prepare_dataset,
remove_columns=dataset.column_names["train"],
column_names = [dataset.column_names[key] for key in dataset.column_names.keys()]
# flatten
columns_to_remove = list(chain.from_iterable(column_names))

dataset = dataset.map(
_prepare_dataset,
remove_columns=columns_to_remove,
num_proc=PROCESSOR_COUNT,
)

logger.debug(f"Dataset post prep: {dataset}")
logger.debug(f"Training labels: {dataset['train']['labels']}")
return dataset
5 changes: 4 additions & 1 deletion elpis/trainer/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
from dataclasses import dataclass, fields
from enum import Enum
from pathlib import Path
from typing import Any, Dict
from typing import Any, Dict, Tuple

import torch
from transformers import TrainingArguments

BASE_MODEL = "facebook/wav2vec2-base-960h"
SAMPLING_RATE = 16_000
METRICS = ("wer", "cer")


class TrainingStatus(Enum):
Expand Down Expand Up @@ -52,6 +53,7 @@ class TrainingJob:
status: TrainingStatus = TrainingStatus.WAITING
base_model: str = BASE_MODEL
sampling_rate: int = SAMPLING_RATE
metrics: Tuple[str, ...] = METRICS

def to_training_args(self, output_dir: Path, **kwargs) -> TrainingArguments:
return TrainingArguments(
Expand Down Expand Up @@ -86,6 +88,7 @@ def from_dict(data: Dict[str, Any]) -> TrainingJob:
status=TrainingStatus(data.get("status", TrainingStatus.WAITING)),
base_model=data.get("base_model", BASE_MODEL),
sampling_rate=data.get("sampling_rate", SAMPLING_RATE),
metrics=data.get("metrics", METRICS),
)

def to_dict(self) -> Dict[str, Any]:
Expand Down
44 changes: 44 additions & 0 deletions elpis/trainer/metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from typing import Callable, Dict, Optional, Sequence

import evaluate
import numpy as np
from loguru import logger
from transformers import EvalPrediction, Wav2Vec2Processor


def create_metrics(
    metric_names: Sequence[str], processor: Wav2Vec2Processor
) -> Optional[Callable[[EvalPrediction], Dict]]:
    """Build a ``compute_metrics`` callback for a HuggingFace ``Trainer``.

    Parameters:
        metric_names: Names of ``evaluate`` metrics to load (e.g. "wer", "cer").
        processor: Processor used to decode predicted ids and label ids back
            into strings before scoring.

    Returns:
        A function mapping an ``EvalPrediction`` to a dict of
        ``{metric_name: score}``, or ``None`` when no metrics were requested.
    """
    if not metric_names:
        # Nothing to compute; Trainer accepts compute_metrics=None.
        return None

    # Note: was using evaluate.combine but was having many unexpected errors.
    metrics = {name: evaluate.load(name) for name in metric_names}

    def compute_metrics(pred: EvalPrediction) -> Dict:
        # taken from https://huggingface.co/blog/fine-tune-xlsr-wav2vec2
        pred_logits = pred.predictions

        # -100 marks positions ignored by the loss; restore the pad id so the
        # tokenizer can decode them. (Mutates pred.label_ids in place.)
        pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id  # type: ignore

        # Taken from: https://discuss.huggingface.co/t/code-review-compute-metrics-for-wer-with-wav2vec2processorwithlm/16841/3
        # Name-based check avoids importing Wav2Vec2ProcessorWithLM directly.
        if type(processor).__name__ == "Wav2Vec2ProcessorWithLM":
            # LM-backed decoding consumes the raw logits and returns an
            # object whose .text attribute holds the decoded strings.
            pred_str = processor.batch_decode(pred_logits).text
        else:
            pred_ids = np.argmax(pred_logits, axis=-1)
            pred_str = processor.batch_decode(pred_ids)

        # We do not want to group tokens when computing the metrics
        label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

        logger.debug(f"METRICS->pred: {pred_str} label:{label_str}")

        result = {
            name: metric.compute(predictions=pred_str, references=label_str)
            for name, metric in metrics.items()
        }
        logger.debug(f"Metrics Result: {result}")
        return result

    return compute_metrics
11 changes: 7 additions & 4 deletions elpis/trainer/trainer.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from contextlib import nullcontext
from pathlib import Path
from typing import Optional
from typing import Dict, Optional

from loguru import logger
from transformers import AutoModelForCTC, AutoProcessor, Trainer
from tokenizers import Tokenizer
from transformers import AutoModelForCTC, AutoProcessor, EvalPrediction, Trainer

from elpis.datasets import create_dataset, prepare_dataset
from elpis.trainer.data_collator import DataCollatorCTCWithPadding
from elpis.trainer.job import TrainingJob
from elpis.trainer.metrics import create_metrics
from elpis.trainer.utils import log_to_file


Expand All @@ -18,7 +20,7 @@ def train(
cache_dir: Optional[Path] = None,
log_file: Optional[Path] = None,
) -> Path:
"""Trains a model for use in transcription.
"""Fine-tunes a model for use in transcription.
Parameters:
job: Info about the training job, e.g. training options.
Expand Down Expand Up @@ -61,6 +63,7 @@ def train(
eval_dataset=dataset["test"], # type: ignore
tokenizer=processor.feature_extractor,
data_collator=data_collator,
compute_metrics=create_metrics(job.metrics, processor),
)

logger.info(f"Begin training model...")
Expand All @@ -74,9 +77,9 @@ def train(
logger.info(f"Model written to disk.")

metrics = trainer.evaluate()
logger.info("==== Metrics ====")
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)
logger.info("==== Metrics ====")
logger.info(metrics)

return output_dir
Loading

0 comments on commit cff43a0

Please sign in to comment.