Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Major Refactor #14

Merged
merged 8 commits into from
Oct 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
*.pyc
testdir
testscript.py

# Packages
*.egg
Expand Down
2 changes: 1 addition & 1 deletion elpis/datasets/clean_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def clean_text(
Returns:
The cleaned text
"""
words = text.lower().split()
words = text.upper().split()

if words_to_remove is not None:
words = filter(lambda word: word not in words_to_remove, words)
Expand Down
76 changes: 45 additions & 31 deletions elpis/datasets/dataset.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

from dataclasses import dataclass, field, fields
from functools import cached_property, reduce
from itertools import chain, groupby
from pathlib import Path
from typing import Any, Dict, List, Optional, Set
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple

from elpis.models import ElanOptions

Expand Down Expand Up @@ -69,6 +71,9 @@ class Dataset:
cleaning_options: CleaningOptions
elan_options: Optional[ElanOptions]

def __post_init__(self):
self.files = sorted(self.files)

def is_empty(self) -> bool:
"""Returns true iff the dataset contains no files."""
return len(self.files) == 0
Expand All @@ -82,15 +87,29 @@ def is_valid(self) -> bool:
return (
not self.is_empty()
and len(self.files) % 2 == 0
and len(self.mismatched_files()) == 0
and len(self.colliding_files()) == 0
and len(self.mismatched_files) == 0
and len(self.colliding_files) == 0
)

@staticmethod
def is_audio(file: Path) -> bool:
return file.suffix == ".wav"

@staticmethod
def is_transcript(file: Path) -> bool:
return file.suffix in TRANSCRIPTION_EXTENSIONS

@staticmethod
def corresponding_audio_name(transcript_file: Path) -> Path:
"""Gets the corresponding audio file name for a given transcript file."""
return Path(transcript_file).parent / (transcript_file.stem + ".wav")

@property
def transcript_files(self) -> Iterable[Path]:
"""Returns an iterable of all transcription files within the dataset."""
return filter(Dataset.is_transcript, self.files)

@cached_property
def mismatched_files(self) -> Set[Path]:
"""Returns the list of transcript files with no corresponding
audio and vice versa.
Expand All @@ -101,18 +120,19 @@ def mismatched_files(self) -> Set[Path]:
Returns:
A list of the mismatched file names.
"""
transcripts_with_audio = set(
filter(
lambda file: Dataset.corresponding_audio_name(file) in self.files,
self._transcript_files(),
)
)
matched_files = transcripts_with_audio | set(
Dataset.corresponding_audio_name(file) for file in transcripts_with_audio
)
grouped_by_stems = groupby(self.files, lambda path: path.stem)

def mismatches(files: Iterable[Path]) -> list[Path]:
files = list(files)
has_audio = any(Dataset.is_audio(file) for file in files)
has_transcript = any(Dataset.is_transcript(file) for file in files)
return [] if has_transcript == has_audio else files

return set(self.files).difference(matched_files)
groups = (mismatches(g) for _, g in grouped_by_stems)
result = set(chain.from_iterable(groups))
return result

@cached_property
def colliding_files(self) -> Set[Path]:
"""Returns the list of transcript file names that collide.

Expand All @@ -122,19 +142,14 @@ def colliding_files(self) -> Set[Path]:
Returns:
A list of the colliding file names.
"""
grouped_by_stems = groupby(self.transcript_files, lambda path: path.stem)

def would_collide(transcript_file: Path) -> bool:
other_files = self._transcript_files().difference({transcript_file})
other_file_names = map(lambda file: Path(file).stem, other_files)
return Path(transcript_file).stem in other_file_names

return set(filter(would_collide, self._transcript_files()))
def collisions(files: Iterable[Path]) -> list[Path]:
files = list(files)
return files if len(files) >= 2 else []

def _transcript_files(self) -> Set[Path]:
"""Returns a set of all transcription files within the dataset."""
return set(
filter(lambda file: file.suffix in TRANSCRIPTION_EXTENSIONS, self.files)
)
collision_groups = (collisions(g) for _, g in grouped_by_stems)
return set(chain.from_iterable(collision_groups))

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> Dataset:
Expand All @@ -155,25 +170,24 @@ def from_dict(cls, data: Dict[str, Any]) -> Dataset:

@property
def valid_transcriptions(self):
return (
self._transcript_files()
.difference(self.mismatched_files())
.difference(self.colliding_files())
is_valid = lambda path: path not in (
self.mismatched_files | self.colliding_files
)
return filter(is_valid, self.transcript_files)

def to_batches(self) -> List[ProcessingBatch]:
def to_batches(self) -> Iterable[ProcessingBatch]:
"""Converts a valid dataset to a list of processing jobs, matching
transcript and audio files.
"""
return [
return (
ProcessingBatch(
transcription_file=transcription_file,
audio_file=self.corresponding_audio_name(transcription_file),
cleaning_options=self.cleaning_options,
elan_options=self.elan_options,
)
for transcription_file in self.valid_transcriptions
]
)

def to_dict(self) -> Dict[str, Any]:
result = {
Expand Down
4 changes: 0 additions & 4 deletions elpis/datasets/extract_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,6 @@ def extract_elan_annotations(
A list of the annotations contained for the supplied data. Returns an
empty list if the given selection isn't found.
"""
logger.info(
f"processing eaf {elan_file_path} using {selection_type}: {selection_data}"
)

match selection_type:
case ElanTierSelector.NAME:
return get_annotations_by_tier_name(elan_file_path, selection_data)
Expand Down
2 changes: 0 additions & 2 deletions elpis/datasets/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
from pathlib import Path
from typing import Iterable, List, Tuple

from loguru import logger

import elpis.utils.audio as audio
from elpis.datasets.clean_text import clean_text
from elpis.datasets.dataset import CleaningOptions, ProcessingBatch
Expand Down
130 changes: 79 additions & 51 deletions elpis/datasets/processing.py
Original file line number Diff line number Diff line change
@@ -1,105 +1,133 @@
import os
from itertools import chain
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List

import numpy as np
from datasets import Audio, DatasetDict, load_dataset
from loguru import logger
from transformers import Wav2Vec2Processor
from transformers import AutoFeatureExtractor, AutoTokenizer

PROCESSOR_COUNT = 4
AUDIO_COLUMN = "audio"
SAMPLING_RATE = 16_000
from elpis.models.job import Job

LOGGING_TRANSCRIPT_SAMPLE = 2


def create_dataset(
dataset_path: Path, cache_dir: Optional[Path] = None, test_size: float = 0.2
job: Job,
test_size: float = 0.2,
) -> DatasetDict:
"""Creates a dataset with test/train splits from the data within a given
directory.

Parameters:
dataset_path: The path to the unprocessed dataset files.
cache_dir: The path to save the processed dataset.
job: The training job to run.
test_size: The percentage of the dataset to allocate as the test set.

Returns:
A dataset dictionary with test and train splits.
"""
dataset_path = Path(job.data_args.dataset_name_or_path)
if not dataset_path.is_dir():
raise ValueError(
f"Attempting to create local dataset from non-existent "
f"directory: {dataset_path}."
)

transcript_files = [
str(dataset_path / file)
for file in os.listdir(dataset_path)
if (dataset_path / file).suffix == ".json"
]
logger.debug(f"Transcript file paths sample: {transcript_files[:4]}")

# Annoying hack
if cache_dir is not None:
cache_dir = str(cache_dir) # type: ignore
logger.debug(
f"Transcript file paths sample: {transcript_files[:LOGGING_TRANSCRIPT_SAMPLE]}"
)

dataset = load_dataset("json", data_files=transcript_files, cache_dir=cache_dir) # type: ignore
dataset = load_dataset("json", data_files=transcript_files, cache_dir=job.model_args.cache_dir) # type: ignore

# Convert the audio file name column into the matching audio data
dataset = dataset.rename_column("audio_file", AUDIO_COLUMN)
logger.debug(f"Dataset audio file paths sample: {dataset['train'][AUDIO_COLUMN][:4]}") # type: ignore
audio_column = job.data_args.audio_column_name
dataset = dataset.rename_column("audio_file", audio_column)
logger.debug(f"Dataset audio file paths sample: {dataset['train'][audio_column][:LOGGING_TRANSCRIPT_SAMPLE]}") # type: ignore

def resolve_audio_path(row: Dict[str, Any]) -> Dict[str, Any]:
# Forcefully resolve to same dir as dataset.
path = dataset_path / Path(row[AUDIO_COLUMN]).name
row[AUDIO_COLUMN] = str(path)
path = dataset_path / Path(row[audio_column]).name
row[audio_column] = str(path.absolute())
return row

dataset = dataset.map(resolve_audio_path)
logger.debug(f"Dataset audio file paths post-resolution: {dataset['train'][AUDIO_COLUMN][:4]}") # type: ignore
dataset = dataset.cast_column(AUDIO_COLUMN, Audio(sampling_rate=SAMPLING_RATE))
logger.debug(f"Dataset audio file paths post-resolution: {dataset['train'][audio_column][:LOGGING_TRANSCRIPT_SAMPLE]}") # type: ignore

return dataset["train"].train_test_split(test_size=test_size) # type: ignore
dataset = dataset["train"].train_test_split(test_size=test_size, seed=job.training_args.seed) # type: ignore
# rename test to eval
dataset["eval"] = dataset["test"]
dataset.pop("test")

return dataset

def prepare_dataset(dataset: DatasetDict, processor: Wav2Vec2Processor) -> DatasetDict:
"""Runs some preprocessing over the given dataset.

TODO: I'm going to be honest, I have no idea what this does, and need some
smart ML knight in shining armour to write a propert description.
def prepare_dataset(
job: Job,
tokenizer: AutoTokenizer,
feature_extractor: AutoFeatureExtractor,
dataset: DatasetDict,
) -> DatasetDict:
"""Runs some preprocessing over the given dataset.

Parameters:
dataset: The dataset to apply the preprocessing
dataset: The dataset on which to apply the preprocessing
processor: The processor to apply over the dataset
"""

logger.debug(f"Dataset pre prep: {dataset}")
logger.debug(f"Dataset[train] pre prep: {dataset['train']['transcript'][0]}")
logger.debug(
f'Input array shape:, {np.asarray(dataset["train"][0]["audio"]["array"]).shape}'
# Load the audio data and resample if necessary.
dataset = dataset.cast_column(
job.data_args.audio_column_name,
Audio(sampling_rate=feature_extractor.sampling_rate), # type: ignore
)
logger.debug(f'Sampling rate:, {dataset["train"][0]["audio"]["sampling_rate"]}')
logger.debug(f"Tokenizer vocab: {processor.tokenizer.vocab}") # type: ignore

def _prepare_dataset(batch: Dict) -> Dict[str, List]:
# Also from https://huggingface.co/blog/fine-tune-xlsr-wav2vec2
audio = batch["audio"]

batch["input_values"] = processor(
audio = batch[job.data_args.audio_column_name]
inputs = feature_extractor( # type: ignore
audio["array"], sampling_rate=audio["sampling_rate"]
).input_values[0]
)

batch["input_values"] = inputs.input_values[0]
batch["input_length"] = len(batch["input_values"])

batch["labels"] = processor(text=batch["transcript"]).input_ids
# encode targets
additional_kwargs = {}
phoneme_language = job.data_args.phoneme_language
if phoneme_language is not None:
additional_kwargs["phonemizer_lang"] = phoneme_language

batch["labels"] = tokenizer(batch[job.data_args.text_column_name], **additional_kwargs).input_ids # type: ignore
return batch

columns = dataset.column_names.values()
# flatten and make unique between datasets
columns_to_remove = list(set(chain.from_iterable(columns)))

dataset = dataset.map(
_prepare_dataset,
remove_columns=columns_to_remove,
num_proc=PROCESSOR_COUNT,
max_input_length = (
job.data_args.max_duration_in_seconds * feature_extractor.sampling_rate # type: ignore
)
min_input_length = (
job.data_args.min_duration_in_seconds * feature_extractor.sampling_rate # type: ignore
)

logger.debug(f"Dataset post prep: {dataset}")
logger.debug(f"Training labels: {dataset['train']['labels'][0]}")
# logger.debug(f"Training inputs: {dataset['train']['input_values'][0]}")
def is_audio_in_length_range(length: int):
return length >= min_input_length and length <= max_input_length

with job.training_args.main_process_first(desc="dataset map preprocessing"):
worker_count = job.data_args.preprocessing_num_workers
dataset = dataset.map(
_prepare_dataset,
remove_columns=next(iter(dataset.values())).column_names,
num_proc=worker_count,
desc="preprocess datasets",
)

# filter data that is shorter than min_input_length
dataset = dataset.filter(
is_audio_in_length_range,
num_proc=worker_count,
input_columns=["input_length"],
)

logger.info(f"Test encoding labels: {dataset['train'][0]['labels']}")

return dataset
Loading
Loading