Skip to content

Commit

Permalink
Merge pull request #35 from descriptinc/ps/batch_sampler
Browse files Browse the repository at this point in the history
Adding a BatchSampler where you can change the batch size.
  • Loading branch information
pseeth authored Apr 7, 2022
2 parents e09cb2f + bcd7fc6 commit f6325ab
Show file tree
Hide file tree
Showing 12 changed files with 114 additions and 50 deletions.
2 changes: 1 addition & 1 deletion audiotools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.3.2"
__version__ = "0.3.3"
from .core import AudioSignal, STFTParams, Meter, util
from . import metrics
from . import data
Expand Down
2 changes: 2 additions & 0 deletions audiotools/core/audio_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,13 @@ def batch(
)
# Concatenate along the batch dimension
audio_data = torch.cat([x.audio_data for x in audio_signals], dim=0)
audio_paths = [x.path_to_input_file for x in audio_signals]

batched_signal = cls(
audio_data,
sample_rate=audio_signals[0].sample_rate,
)
batched_signal.path_to_input_file = audio_paths
return batched_signal

# I/O
Expand Down
66 changes: 38 additions & 28 deletions audiotools/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@
from multiprocessing import Manager
from typing import List

import torch
from flatten_dict import flatten
from flatten_dict import unflatten
from torch.utils.data import BatchSampler as _BatchSampler
from torch.utils.data import SequentialSampler
from torch.utils.data.distributed import DistributedSampler

Expand All @@ -13,10 +11,38 @@

# We need to set SHARED_KEYS statically, with no relationship to the
# BaseDataset object, or we'll hit RecursionErrors in the lookup.
SHARED_KEYS = ["duration", "shared_transform", "check_transform", "sample_rate"]
SHARED_KEYS = [
"duration",
"shared_transform",
"check_transform",
"sample_rate",
"batch_size",
]


class BaseDataset:
class SharedMixin:
def __getattribute__(self, name: str):
# Look up the name in SHARED_KEYS (see above). If it's there,
# return it from the dictionary that is kept in shared memory.
# Otherwise, do the normal __getattribute__. This line only
# runs if the key is in SHARED_KEYS.
if name in SHARED_KEYS:
return self.shared_dict[name]
else:
return super().__getattribute__(name)

def __setattr__(self, name, value):
# Look up the name in SHARED_KEYS (see above). If it's there
# set the value in the dictionary accordingly, so that it the other
# dataset replicas know about it. Otherwise, do the normal
# __setattr__. This line only runs if the key is in SHARED_KEYS.
if name in SHARED_KEYS:
self.shared_dict[name] = value
else:
super().__setattr__(name, value)


class BaseDataset(SharedMixin):
"""This BaseDataset class adds all the necessary logic so that there is
a dictionary that is shared across processes when working with a
DataLoader with num_workers > 0. It adds an attribute called
Expand Down Expand Up @@ -65,26 +91,6 @@ def transform(self, value):
self.shared_transform = value
self.check_transform = True

def __getattribute__(self, name: str):
# Look up the name in SHARED_KEYS (see above). If it's there,
# return it from the dictionary that is kept in shared memory.
# Otherwise, do the normal __getattribute__. This line only
# runs if the key is in SHARED_KEYS.
if name in SHARED_KEYS:
return self.shared_dict[name]
else:
return super().__getattribute__(name)

def __setattr__(self, name, value):
# Look up the name in SHARED_KEYS (see above). If it's there
# set the value in the dictionary accordingly, so that it the other
# dataset replicas know about it. Otherwise, do the normal
# __setattr__. This line only runs if the key is in SHARED_KEYS.
if name in SHARED_KEYS:
self.shared_dict[name] = value
else:
super().__setattr__(name, value)

def __len__(self):
return self.length

Expand Down Expand Up @@ -124,21 +130,25 @@ def __getitem__(self, idx):
state=state,
loudness_cutoff=self.loudness_cutoff,
)
if "loudness" in audio_info:
signal.metadata["file_loudness"] = float(audio_info["loudness"])
if self.mono:
signal = signal.to_mono()
signal = signal.resample(self.sample_rate)

# Instantiate the transform.
item = {"signal": signal}
item = {"idx": idx, "signal": signal}
if self.transform is not None:
item["transform_args"] = self.transform.instantiate(state, signal=signal)

return item


# Samplers
class BatchSampler(_BatchSampler, SharedMixin):
def __init__(self, sampler, batch_size: int, drop_last: bool = False):
self.shared_dict = Manager().dict()
super().__init__(sampler, batch_size, drop_last=drop_last)


class ResumableDistributedSampler(DistributedSampler): # pragma: no cover
def __init__(self, dataset, start_idx=None, **kwargs):
super().__init__(dataset, **kwargs)
Expand Down
7 changes: 1 addition & 6 deletions audiotools/data/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from ..core import AudioSignal


def create_csv(audio_files: list, output_csv: Path, loudness=False):
def create_csv(audio_files: list, output_csv: Path):
"""Converts a folder of audio files to a CSV file.
Parameters
Expand All @@ -15,9 +15,6 @@ def create_csv(audio_files: list, output_csv: Path, loudness=False):
output_csv : Path
Output CSV, with each row containing the relative path of every file
to PATH_TO_DATA (defaults to None).
loudness : bool, optional
Loudness of every file is computed and put into CSV, if True,
by default False
"""
data_path = Path(os.getenv("PATH_TO_DATA", ""))

Expand All @@ -26,8 +23,6 @@ def create_csv(audio_files: list, output_csv: Path, loudness=False):
af = Path(af)
_info = {}
_info["path"] = af.relative_to(data_path)
if loudness:
_info["loudness"] = AudioSignal(af).ffmpeg_loudness().item()
info.append(_info)

with open(output_csv, "w") as f:
Expand Down
21 changes: 13 additions & 8 deletions audiotools/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,20 +431,19 @@ def _transform(self, signal, db):
class VolumeNorm(BaseTransform):
def __init__(
self,
db: float = -24,
db: tuple = ("const", -24),
name: str = None,
prob: float = 1.0,
):
super().__init__(name=name, prob=prob)

self.db = db

def _instantiate(self, state: RandomState, signal: AudioSignal = None):
return {"loudness": signal.metadata["file_loudness"]}
def _instantiate(self, state: RandomState):
return {"db": util.sample_from_dist(self.db, state)}

def _transform(self, signal, loudness):
db_change = self.db - loudness
return signal.volume_change(db_change)
def _transform(self, signal, db):
return signal.normalize(db)


class Silence(BaseTransform):
Expand Down Expand Up @@ -676,7 +675,13 @@ def _instantiate(self, state: RandomState, signal: AudioSignal = None):
return {"window": AudioSignal(window, signal.sample_rate)}

def _transform(self, signal, window):
scale = signal.audio_data.abs().max(dim=-1, keepdim=True).values
sscale = signal.audio_data.abs().max(dim=-1, keepdim=True).values
sscale[sscale == 0.0] = 1.0

out = signal.convolve(window)
out = out * scale / out.audio_data.abs().max(dim=-1, keepdim=True).values

oscale = out.audio_data.abs().max(dim=-1, keepdim=True).values
oscale[oscale == 0.0] = 1.0

out = out * (sscale / oscale)
return out
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

setup(
name="audiotools",
version="0.3.2",
version="0.3.3",
classifiers=[
"Intended Audience :: Developers",
"Intended Audience :: Education",
Expand Down
4 changes: 2 additions & 2 deletions tests/audio/spk.csv
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
path,loudness
tests/audio/spk/f10_script4_produced.wav,-16.5
path
tests/audio/spk/f10_script4_produced.wav
3 changes: 3 additions & 0 deletions tests/core/test_audio_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -507,7 +507,10 @@ def test_batching():

signal_lengths = [x.signal_length for x in signals]
max_length = max(signal_lengths)
for i, x in enumerate(signals):
x.path_to_input_file = i
batched_signal = AudioSignal.batch(signals, resample=True, pad_signals=True)

assert batched_signal.signal_length == max_length
assert batched_signal.batch_size == batch_size
assert batched_signal.path_to_input_file == list(range(len(signals)))
41 changes: 41 additions & 0 deletions tests/data/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import audiotools
from audiotools import AudioSignal
from audiotools import data
from audiotools.data import transforms as tfm


Expand Down Expand Up @@ -111,6 +112,46 @@ def test_shared_transform():
assert num_succeeded >= 2


def test_batch_sampler():
for nw in (0, 1, 2):
dataset = audiotools.data.datasets.CSVDataset(
44100,
n_examples=100,
csv_files=["tests/audio/spk.csv"],
)

sampler = audiotools.datasets.BatchSampler(
audiotools.datasets.SequentialSampler(dataset), batch_size=1, drop_last=True
)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_sampler=sampler,
num_workers=nw,
collate_fn=dataset.collate,
)

targets = {"bs": [1]}
observed = {"bs": []}

for new_bs in [1, 5, 10]:
dataloader.batch_sampler.batch_size = new_bs
targets["bs"].append(new_bs)

for batch in dataloader:
actual_bs = batch["signal"].batch_size
observed["bs"].append(actual_bs)

for k in targets:
_targets = [int(x) for x in targets[k]]
_observed = [int(x) for x in observed[k]]

num_succeeded = 0
for val in np.unique(_observed):
assert any([x == val for x in _targets])
num_succeeded += 1
assert num_succeeded >= 2


def test_csv_dataset():
transform = tfm.Compose(
[
Expand Down
4 changes: 1 addition & 3 deletions tests/data/test_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,4 @@

def test_create_csv():
with tempfile.NamedTemporaryFile(suffix=".csv") as f:
preprocess.create_csv(
find_audio("./tests/audio/spk", ext=["wav"]), f.name, loudness=True
)
preprocess.create_csv(find_audio("./tests/audio/spk", ext=["wav"]), f.name)
10 changes: 10 additions & 0 deletions tests/data/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,3 +393,13 @@ def test_nested_masking():
kwargs = batch["transform_args"]
with torch.no_grad():
output = dataset.transform(signal, **kwargs)


def test_smoothing_edge_case():
transform = tfm.Smoothing()
zeros = torch.zeros(1, 1, 44100)
signal = AudioSignal(zeros, 44100)
kwargs = transform.instantiate(0, signal)
output = transform(signal, **kwargs)

assert torch.allclose(output.audio_data, zeros)
2 changes: 1 addition & 1 deletion tests/regression/transforms/VolumeNorm.wav
Git LFS file not shown

0 comments on commit f6325ab

Please sign in to comment.