Merge pull request #31 from descriptinc/ps/fix_rir
Add option to `apply_ir` to re-use original phase.
pseeth authored Mar 17, 2022
2 parents d2a01cf + 9c28001 commit 23b9ffe
Showing 9 changed files with 46 additions and 24 deletions.
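
The new `use_original_phase` keyword lets `apply_ir` keep the dry signal's STFT phase after the reverb is applied. A minimal usage sketch mirroring the call added to the test suite below; the file paths are illustrative, and the flat 6-band EQ curve is an assumption about the `ir_eq` format:

import torch
from audiotools import AudioSignal

spk = AudioSignal("speech.wav")   # illustrative path to a dry speech clip
ir = AudioSignal("room_ir.wav")   # illustrative path to a room impulse response

eq_db = torch.zeros(6)  # assumed: a flat per-band EQ curve in dB

# New in this commit: convolve with the IR but keep the dry signal's phase.
wet = spk.deepcopy().apply_ir(ir, drr=10, ir_eq=eq_db, use_original_phase=True)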
2 changes: 1 addition & 1 deletion audiotools/__init__.py
@@ -1,4 +1,4 @@
__version__ = "0.2.8"
__version__ = "0.2.9"
from .core import AudioSignal, STFTParams, Meter, util
from . import metrics
from . import data
45 changes: 30 additions & 15 deletions audiotools/core/effects.py
@@ -62,17 +62,7 @@ def convolve(self, other, start_at_max=True):
This function uses FFTs to do the convolution.
"""
from .audio_signal import AudioSignal

if start_at_max:
idx = other.audio_data.abs().argmax(axis=-1)
weights = [
AudioSignal(
other.audio_data[i, ..., idx[i] :], sample_rate=other.sample_rate
)
for i in range(other.batch_size)
]
other = AudioSignal.batch(weights, pad_signals=True)
from . import AudioSignal

pad_len = self.signal_length - other.signal_length

@@ -81,21 +71,39 @@ def convolve(self, other, start_at_max=True):
else:
other.truncate_samples(self.signal_length)

other.audio_data /= torch.norm(
other.audio_data.clamp(min=1e-8), p=2, dim=-1, keepdim=True
)
if start_at_max:
# Use roll to rotate over the max for every item
# so that the impulse responses don't induce any
# delay.
idx = other.audio_data.abs().argmax(axis=-1)
irs = torch.zeros_like(other.audio_data)
for i in range(other.batch_size):
irs[i] = torch.roll(other.audio_data[i], -idx[i].item(), -1)
other = AudioSignal(irs, other.sample_rate)

delta = torch.zeros_like(other.audio_data)
delta[..., 0] = 1

delta_fft = torch.fft.rfft(delta)
other_fft = torch.fft.rfft(other.audio_data)
self_fft = torch.fft.rfft(self.audio_data)

convolved_fft = other_fft * self_fft
convolved_audio = torch.fft.irfft(convolved_fft)

delta_convolved_fft = other_fft * delta_fft
delta_audio = torch.fft.irfft(delta_convolved_fft)

# Use the delta to rescale the audio exactly as needed.
delta_max = delta_audio.abs().max(dim=-1, keepdims=True)[0]
scale = 1 / delta_max.clamp(1e-5)
convolved_audio = convolved_audio * scale

self.audio_data = convolved_audio

return self

def apply_ir(self, ir, drr=None, ir_eq=None, rescale=True):
def apply_ir(self, ir, drr=None, ir_eq=None, use_original_phase=False):
if ir_eq is not None:
ir = ir.equalizer(ir_eq)
if drr is not None:
@@ -106,8 +114,15 @@ def apply_ir(self, ir, drr=None, ir_eq=None, rescale=True):

# Augment the impulse response to simulate microphone effects
# and with varying direct-to-reverberant ratio.
phase = self.phase
self.convolve(ir)

# Use the input phase
if use_original_phase:
self.stft()
self.stft_data = self.magnitude * torch.exp(1j * phase)
self.istft()

# Rescale to the input's amplitude
max_transformed = self.audio_data.abs().max(dim=-1, keepdims=True).values
scale_factor = max_spk.clamp(1e-8) / max_transformed.clamp(1e-8)
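Two ideas drive the rewritten `convolve` above: each impulse response is rolled so its largest sample sits at index 0 (so the reverb adds no delay), and the output is rescaled by how the IR transforms a unit impulse rather than by the IR's L2 norm as before. A standalone sketch of the same idea in plain PyTorch, assuming `(batch, time)` tensors and an IR shorter than the signal (names and shapes are illustrative, not the library's API):

import torch
import torch.nn.functional as F

def convolve_no_delay(x: torch.Tensor, ir: torch.Tensor) -> torch.Tensor:
    # Roll each IR so its largest-magnitude sample lands at index 0,
    # so the convolution does not delay the signal.
    idx = ir.abs().argmax(dim=-1)
    ir = torch.stack(
        [torch.roll(ir[i], -int(idx[i]), dims=-1) for i in range(ir.shape[0])]
    )

    # Zero-pad the IR out to the signal length so the spectra line up.
    ir = F.pad(ir, (0, x.shape[-1] - ir.shape[-1]))

    x_fft = torch.fft.rfft(x)
    ir_fft = torch.fft.rfft(ir)
    wet = torch.fft.irfft(ir_fft * x_fft, n=x.shape[-1])

    # Convolve a unit impulse with the same IR; its peak measures how much
    # the IR amplifies a lone sample, and is used to normalize the output.
    delta = torch.zeros_like(ir)
    delta[..., 0] = 1
    delta_wet = torch.fft.irfft(ir_fft * torch.fft.rfft(delta), n=x.shape[-1])
    scale = 1 / delta_wet.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    return wet * scale

Because the delta path goes through exactly the same rolled, padded IR, the rescaling matches the convolution actually applied.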
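The `use_original_phase` branch of `apply_ir` keeps the reverberated magnitude but swaps the dry signal's phase back in before resynthesis. A rough standalone equivalent using plain PyTorch STFTs; the window, FFT size, and hop length are illustrative and need not match audiotools' own STFT parameters:

import torch

def reuse_dry_phase(dry: torch.Tensor, wet: torch.Tensor,
                    n_fft: int = 2048, hop: int = 512) -> torch.Tensor:
    # dry, wet: (batch, time) tensors of equal length.
    window = torch.hann_window(n_fft)
    dry_stft = torch.stft(dry, n_fft, hop_length=hop, window=window, return_complex=True)
    wet_stft = torch.stft(wet, n_fft, hop_length=hop, window=window, return_complex=True)

    # Reverberated magnitude, original (dry) phase.
    combined = wet_stft.abs() * torch.exp(1j * dry_stft.angle())
    return torch.istft(combined, n_fft, hop_length=hop, window=window,
                       length=wet.shape[-1])

This mirrors `self.stft_data = self.magnitude * torch.exp(1j * phase)` in the diff, where `phase` is captured before `convolve` runs.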
7 changes: 6 additions & 1 deletion audiotools/data/transforms.py
@@ -7,6 +7,7 @@
import torch
from flatten_dict import flatten
from flatten_dict import unflatten
from matplotlib import use
from numpy.random import RandomState

from ..core import AudioSignal
@@ -325,12 +326,14 @@ def __init__(
n_bands: int = 6,
name: str = None,
prob: float = 1.0,
use_original_phase: bool = False,
):
super().__init__(name=name, prob=prob)

self.drr = drr
self.eq_amount = eq_amount
self.n_bands = n_bands
self.use_original_phase = use_original_phase
self.audio_files = util.read_csv(csv_files)

def _instantiate(self, state: RandomState, signal: AudioSignal = None):
@@ -358,7 +361,9 @@ def _instantiate(self, state: RandomState, signal: AudioSignal = None):
def _transform(self, signal, ir_signal, drr, eq):
# Clone ir_signal so that transform can be repeatedly applied
# to different signals with the same effect.
return signal.apply_ir(ir_signal.clone(), drr, eq)
return signal.apply_ir(
ir_signal.clone(), drr, eq, use_original_phase=self.use_original_phase
)


class VolumeChange(BaseTransform):
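On the transform side, the flag is simply stored on `RoomImpulseResponse` and forwarded to `apply_ir` inside `_transform`. A construction sketch; the CSV path is hypothetical and the remaining arguments keep their defaults:

from audiotools.data.transforms import RoomImpulseResponse

# "irs.csv" is a hypothetical CSV listing impulse-response audio files.
transform = RoomImpulseResponse(
    csv_files=["irs.csv"],
    use_original_phase=True,  # new keyword added in this commit
)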
2 changes: 1 addition & 1 deletion setup.py
@@ -6,7 +6,7 @@

setup(
name="audiotools",
version="0.2.8",
version="0.2.9",
classifiers=[
"Intended Audience :: Developers",
"Intended Audience :: Education",
2 changes: 2 additions & 0 deletions tests/core/test_effects.py
@@ -340,6 +340,8 @@ def test_apply_ir():

assert np.allclose(ir.measure_drr().flatten(), 10)

output = spk.deepcopy().apply_ir(ir, drr=10, ir_eq=db, use_original_phase=True)


def test_ensure_max_of_audio():
spk = AudioSignal(torch.randn(1, 1, 44100), 44100)
6 changes: 3 additions & 3 deletions tests/data/test_transforms.py
@@ -23,9 +23,9 @@ def _compare_transform(transform_name, signal):

if regression_data.exists():
regression_signal = AudioSignal(regression_data)
regression_signal.loudness()
signal.loudness()
assert signal == regression_signal
assert torch.allclose(
signal.audio_data, regression_signal.audio_data, atol=1e-6
)
else:
signal.write(regression_data)

2 changes: 1 addition & 1 deletion tests/regression/transforms/Choose.wav
Git LFS file not shown
2 changes: 1 addition & 1 deletion tests/regression/transforms/Compose.wav
Git LFS file not shown
2 changes: 1 addition & 1 deletion tests/regression/transforms/RoomImpulseResponse.wav
Git LFS file not shown
