From 4bb215a1a2be649685ea71837619442a641fb760 Mon Sep 17 00:00:00 2001 From: pseeth Date: Thu, 31 Mar 2022 12:25:46 -0700 Subject: [PATCH 1/9] Adding a BatchSampler where you can change the batch size. --- audiotools/__init__.py | 2 +- audiotools/data/datasets.py | 62 ++++++++++++++++++++++--------------- setup.py | 2 +- tests/data/test_datasets.py | 41 ++++++++++++++++++++++++ 4 files changed, 80 insertions(+), 27 deletions(-) diff --git a/audiotools/__init__.py b/audiotools/__init__.py index 0f043876..14fa0e22 100644 --- a/audiotools/__init__.py +++ b/audiotools/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.3.2" +__version__ = "0.3.3" from .core import AudioSignal, STFTParams, Meter, util from . import metrics from . import data diff --git a/audiotools/data/datasets.py b/audiotools/data/datasets.py index 99bb4d28..51a58677 100644 --- a/audiotools/data/datasets.py +++ b/audiotools/data/datasets.py @@ -2,9 +2,7 @@ from multiprocessing import Manager from typing import List -import torch -from flatten_dict import flatten -from flatten_dict import unflatten +from torch.utils.data import BatchSampler as _BatchSampler from torch.utils.data import SequentialSampler from torch.utils.data.distributed import DistributedSampler @@ -13,10 +11,38 @@ # We need to set SHARED_KEYS statically, with no relationship to the # BaseDataset object, or we'll hit RecursionErrors in the lookup. -SHARED_KEYS = ["duration", "shared_transform", "check_transform", "sample_rate"] +SHARED_KEYS = [ + "duration", + "shared_transform", + "check_transform", + "sample_rate", + "batch_size", +] -class BaseDataset: +class SharedMixin: + def __getattribute__(self, name: str): + # Look up the name in SHARED_KEYS (see above). If it's there, + # return it from the dictionary that is kept in shared memory. + # Otherwise, do the normal __getattribute__. This line only + # runs if the key is in SHARED_KEYS. + if name in SHARED_KEYS: + return self.shared_dict[name] + else: + return super().__getattribute__(name) + + def __setattr__(self, name, value): + # Look up the name in SHARED_KEYS (see above). If it's there + # set the value in the dictionary accordingly, so that it the other + # dataset replicas know about it. Otherwise, do the normal + # __setattr__. This line only runs if the key is in SHARED_KEYS. + if name in SHARED_KEYS: + self.shared_dict[name] = value + else: + super().__setattr__(name, value) + + +class BaseDataset(SharedMixin): """This BaseDataset class adds all the necessary logic so that there is a dictionary that is shared across processes when working with a DataLoader with num_workers > 0. It adds an attribute called @@ -65,26 +91,6 @@ def transform(self, value): self.shared_transform = value self.check_transform = True - def __getattribute__(self, name: str): - # Look up the name in SHARED_KEYS (see above). If it's there, - # return it from the dictionary that is kept in shared memory. - # Otherwise, do the normal __getattribute__. This line only - # runs if the key is in SHARED_KEYS. - if name in SHARED_KEYS: - return self.shared_dict[name] - else: - return super().__getattribute__(name) - - def __setattr__(self, name, value): - # Look up the name in SHARED_KEYS (see above). If it's there - # set the value in the dictionary accordingly, so that it the other - # dataset replicas know about it. Otherwise, do the normal - # __setattr__. This line only runs if the key is in SHARED_KEYS. - if name in SHARED_KEYS: - self.shared_dict[name] = value - else: - super().__setattr__(name, value) - def __len__(self): return self.length @@ -139,6 +145,12 @@ def __getitem__(self, idx): # Samplers +class BatchSampler(_BatchSampler, SharedMixin): + def __init__(self, sampler, batch_size: int, drop_last: bool = False): + self.shared_dict = Manager().dict() + super().__init__(sampler, batch_size, drop_last=drop_last) + + class ResumableDistributedSampler(DistributedSampler): # pragma: no cover def __init__(self, dataset, start_idx=None, **kwargs): super().__init__(dataset, **kwargs) diff --git a/setup.py b/setup.py index aadf0df1..ce02a074 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setup( name="audiotools", - version="0.3.2", + version="0.3.3", classifiers=[ "Intended Audience :: Developers", "Intended Audience :: Education", diff --git a/tests/data/test_datasets.py b/tests/data/test_datasets.py index 1d909df9..83ea6a0d 100644 --- a/tests/data/test_datasets.py +++ b/tests/data/test_datasets.py @@ -3,6 +3,7 @@ import audiotools from audiotools import AudioSignal +from audiotools import data from audiotools.data import transforms as tfm @@ -111,6 +112,46 @@ def test_shared_transform(): assert num_succeeded >= 2 +def test_batch_sampler(): + for nw in (0, 1, 2): + dataset = audiotools.data.datasets.CSVDataset( + 44100, + n_examples=100, + csv_files=["tests/audio/spk.csv"], + ) + + sampler = audiotools.datasets.BatchSampler( + audiotools.datasets.SequentialSampler(dataset), batch_size=1, drop_last=True + ) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_sampler=sampler, + num_workers=nw, + collate_fn=dataset.collate, + ) + + targets = {"bs": [1]} + observed = {"bs": []} + + for new_bs in [1, 5, 10]: + dataloader.batch_sampler.batch_size = new_bs + targets["bs"].append(new_bs) + + for batch in dataloader: + actual_bs = batch["signal"].batch_size + observed["bs"].append(actual_bs) + + for k in targets: + _targets = [int(x) for x in targets[k]] + _observed = [int(x) for x in observed[k]] + + num_succeeded = 0 + for val in np.unique(_observed): + assert any([x == val for x in _targets]) + num_succeeded += 1 + assert num_succeeded >= 2 + + def test_csv_dataset(): transform = tfm.Compose( [ From 12bc336df8e3a39d395d1e92910311dd6f2f6305 Mon Sep 17 00:00:00 2001 From: lgestin Date: Mon, 4 Apr 2022 13:33:37 -0400 Subject: [PATCH 2/9] dataset getitem returns idx in dict --- audiotools/data/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/audiotools/data/datasets.py b/audiotools/data/datasets.py index 51a58677..f6ef96a2 100644 --- a/audiotools/data/datasets.py +++ b/audiotools/data/datasets.py @@ -137,7 +137,7 @@ def __getitem__(self, idx): signal = signal.resample(self.sample_rate) # Instantiate the transform. - item = {"signal": signal} + item = {"idx": idx, "signal": signal} if self.transform is not None: item["transform_args"] = self.transform.instantiate(state, signal=signal) From 1827dc5f422427fe73dc83c97248f12188696111 Mon Sep 17 00:00:00 2001 From: pseeth Date: Wed, 6 Apr 2022 11:50:21 -0700 Subject: [PATCH 3/9] Removing file loudness stuff. --- audiotools/data/datasets.py | 2 -- audiotools/data/preprocess.py | 7 +------ audiotools/data/transforms.py | 11 +++++------ tests/audio/spk.csv | 4 ++-- tests/regression/transforms/VolumeNorm.wav | 2 +- 5 files changed, 9 insertions(+), 17 deletions(-) diff --git a/audiotools/data/datasets.py b/audiotools/data/datasets.py index f6ef96a2..39abc5ba 100644 --- a/audiotools/data/datasets.py +++ b/audiotools/data/datasets.py @@ -130,8 +130,6 @@ def __getitem__(self, idx): state=state, loudness_cutoff=self.loudness_cutoff, ) - if "loudness" in audio_info: - signal.metadata["file_loudness"] = float(audio_info["loudness"]) if self.mono: signal = signal.to_mono() signal = signal.resample(self.sample_rate) diff --git a/audiotools/data/preprocess.py b/audiotools/data/preprocess.py index 9d0f126b..7f072829 100644 --- a/audiotools/data/preprocess.py +++ b/audiotools/data/preprocess.py @@ -5,7 +5,7 @@ from ..core import AudioSignal -def create_csv(audio_files: list, output_csv: Path, loudness=False): +def create_csv(audio_files: list, output_csv: Path): """Converts a folder of audio files to a CSV file. Parameters @@ -15,9 +15,6 @@ def create_csv(audio_files: list, output_csv: Path, loudness=False): output_csv : Path Output CSV, with each row containing the relative path of every file to PATH_TO_DATA (defaults to None). - loudness : bool, optional - Loudness of every file is computed and put into CSV, if True, - by default False """ data_path = Path(os.getenv("PATH_TO_DATA", "")) @@ -26,8 +23,6 @@ def create_csv(audio_files: list, output_csv: Path, loudness=False): af = Path(af) _info = {} _info["path"] = af.relative_to(data_path) - if loudness: - _info["loudness"] = AudioSignal(af).ffmpeg_loudness().item() info.append(_info) with open(output_csv, "w") as f: diff --git a/audiotools/data/transforms.py b/audiotools/data/transforms.py index 5f4022e9..46a40ef6 100644 --- a/audiotools/data/transforms.py +++ b/audiotools/data/transforms.py @@ -431,7 +431,7 @@ def _transform(self, signal, db): class VolumeNorm(BaseTransform): def __init__( self, - db: float = -24, + db: tuple = ("const", -24), name: str = None, prob: float = 1.0, ): @@ -439,12 +439,11 @@ def __init__( self.db = db - def _instantiate(self, state: RandomState, signal: AudioSignal = None): - return {"loudness": signal.metadata["file_loudness"]} + def _instantiate(self, state: RandomState): + return {"db": util.sample_from_dist(self.db, state)} - def _transform(self, signal, loudness): - db_change = self.db - loudness - return signal.volume_change(db_change) + def _transform(self, signal, db): + return signal.normalize(db) class Silence(BaseTransform): diff --git a/tests/audio/spk.csv b/tests/audio/spk.csv index 4bb964c2..2d7b5728 100644 --- a/tests/audio/spk.csv +++ b/tests/audio/spk.csv @@ -1,2 +1,2 @@ -path,loudness -tests/audio/spk/f10_script4_produced.wav,-16.5 +path +tests/audio/spk/f10_script4_produced.wav diff --git a/tests/regression/transforms/VolumeNorm.wav b/tests/regression/transforms/VolumeNorm.wav index e963c5ee..2c538385 100644 --- a/tests/regression/transforms/VolumeNorm.wav +++ b/tests/regression/transforms/VolumeNorm.wav @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a1b3416db578cb541a5cb360c22e2b0a36fd3a57d644b7ce8b804108a3f84510 +oid sha256:86538fbb9a4b749caf414888ea8d98989720e3129505f7c9e5067143aeb5509e size 352858 From c12f17bfdc8a5b1088b1f1849685aebac8bdefe9 Mon Sep 17 00:00:00 2001 From: pseeth Date: Wed, 6 Apr 2022 12:20:45 -0700 Subject: [PATCH 4/9] Collecting input paths. --- audiotools/core/audio_signal.py | 2 ++ tests/core/test_audio_signal.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/audiotools/core/audio_signal.py b/audiotools/core/audio_signal.py index b6d52817..87de8a26 100644 --- a/audiotools/core/audio_signal.py +++ b/audiotools/core/audio_signal.py @@ -154,11 +154,13 @@ def batch( ) # Concatenate along the batch dimension audio_data = torch.cat([x.audio_data for x in audio_signals], dim=0) + audio_paths = [x.path_to_input_file for x in audio_signals] batched_signal = cls( audio_data, sample_rate=audio_signals[0].sample_rate, ) + batched_signal.path_to_input_file = audio_paths return batched_signal # I/O diff --git a/tests/core/test_audio_signal.py b/tests/core/test_audio_signal.py index 1b1337f1..483014ee 100644 --- a/tests/core/test_audio_signal.py +++ b/tests/core/test_audio_signal.py @@ -507,7 +507,10 @@ def test_batching(): signal_lengths = [x.signal_length for x in signals] max_length = max(signal_lengths) + for i, x in enumerate(signals): + x.path_to_input_file = i batched_signal = AudioSignal.batch(signals, resample=True, pad_signals=True) assert batched_signal.signal_length == max_length assert batched_signal.batch_size == batch_size + assert batched_signal.path_to_input_file == list(range(len(signals))) From 4b67678f3a3f1b3257ae591fdc7efe289dcf5505 Mon Sep 17 00:00:00 2001 From: lgestin Date: Wed, 6 Apr 2022 16:34:22 -0400 Subject: [PATCH 5/9] fix smoothing scaling when signal is 0 --- audiotools/data/transforms.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/audiotools/data/transforms.py b/audiotools/data/transforms.py index 46a40ef6..f0d05776 100644 --- a/audiotools/data/transforms.py +++ b/audiotools/data/transforms.py @@ -675,7 +675,15 @@ def _instantiate(self, state: RandomState, signal: AudioSignal = None): return {"window": AudioSignal(window, signal.sample_rate)} def _transform(self, signal, window): - scale = signal.audio_data.abs().max(dim=-1, keepdim=True).values + sscale = signal.audio_data.abs().max(dim=-1, keepdim=True).values + if sscale == 0.0: + sscale = 1.0 + out = signal.convolve(window) - out = out * scale / out.audio_data.abs().max(dim=-1, keepdim=True).values + + oscale = out.audio_data.abs().max(dim=-1, keepdim=True).values + if oscale == 0.0: + oscale = 1.0 + + out = out * (sscale / oscale) return out From 12103ed790d6802a24703683e02a58a8f36f637d Mon Sep 17 00:00:00 2001 From: lgestin Date: Wed, 6 Apr 2022 16:45:38 -0400 Subject: [PATCH 6/9] made the fix work for batch input --- audiotools/data/transforms.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/audiotools/data/transforms.py b/audiotools/data/transforms.py index f0d05776..6b5e72eb 100644 --- a/audiotools/data/transforms.py +++ b/audiotools/data/transforms.py @@ -676,14 +676,12 @@ def _instantiate(self, state: RandomState, signal: AudioSignal = None): def _transform(self, signal, window): sscale = signal.audio_data.abs().max(dim=-1, keepdim=True).values - if sscale == 0.0: - sscale = 1.0 - + if torch.any(sscale == 0.0): + sscale[sscale == 0.0] = 1.0 out = signal.convolve(window) oscale = out.audio_data.abs().max(dim=-1, keepdim=True).values - if oscale == 0.0: - oscale = 1.0 - + if torch.any(oscale == 0.0): + oscale[oscale == 0.0] = 1.0 out = out * (sscale / oscale) return out From 849c844ef06a60e17a711355fae3e8d56a38c121 Mon Sep 17 00:00:00 2001 From: pseeth Date: Wed, 6 Apr 2022 18:05:20 -0700 Subject: [PATCH 7/9] Fixing create_csv test. --- tests/data/test_preprocess.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/data/test_preprocess.py b/tests/data/test_preprocess.py index 36977432..7b68bfb2 100644 --- a/tests/data/test_preprocess.py +++ b/tests/data/test_preprocess.py @@ -7,6 +7,4 @@ def test_create_csv(): with tempfile.NamedTemporaryFile(suffix=".csv") as f: - preprocess.create_csv( - find_audio("./tests/audio/spk", ext=["wav"]), f.name, loudness=True - ) + preprocess.create_csv(find_audio("./tests/audio/spk", ext=["wav"]), f.name) From 1cf912e285bc0b9be7a4bd506c9939ef3e88a14f Mon Sep 17 00:00:00 2001 From: pseeth Date: Wed, 6 Apr 2022 18:14:18 -0700 Subject: [PATCH 8/9] Catching edge case. --- tests/data/test_transforms.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/data/test_transforms.py b/tests/data/test_transforms.py index a897a6ac..db9df20a 100644 --- a/tests/data/test_transforms.py +++ b/tests/data/test_transforms.py @@ -393,3 +393,13 @@ def test_nested_masking(): kwargs = batch["transform_args"] with torch.no_grad(): output = dataset.transform(signal, **kwargs) + + +def test_smoothing_edge_case(): + transform = tfm.Smoothing() + zeros = torch.zeros(1, 1, 44100) + signal = AudioSignal(zeros, 44100) + kwargs = transform.instantiate(0, signal) + output = transform(signal, **kwargs) + + assert torch.allclose(output.audio_data, zeros) From bcd7fc61f4e97c87169b1c9f7d6b35c7d2c6a0f9 Mon Sep 17 00:00:00 2001 From: lgestin Date: Thu, 7 Apr 2022 14:11:13 -0400 Subject: [PATCH 9/9] remove unnecessary check --- audiotools/data/transforms.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/audiotools/data/transforms.py b/audiotools/data/transforms.py index 6b5e72eb..80b617db 100644 --- a/audiotools/data/transforms.py +++ b/audiotools/data/transforms.py @@ -676,12 +676,12 @@ def _instantiate(self, state: RandomState, signal: AudioSignal = None): def _transform(self, signal, window): sscale = signal.audio_data.abs().max(dim=-1, keepdim=True).values - if torch.any(sscale == 0.0): - sscale[sscale == 0.0] = 1.0 + sscale[sscale == 0.0] = 1.0 + out = signal.convolve(window) oscale = out.audio_data.abs().max(dim=-1, keepdim=True).values - if torch.any(oscale == 0.0): - oscale[oscale == 0.0] = 1.0 + oscale[oscale == 0.0] = 1.0 + out = out * (sscale / oscale) return out