diff --git a/maha_tts/utils/audio.py b/maha_tts/utils/audio.py index 715b6b2..000c4ad 100644 --- a/maha_tts/utils/audio.py +++ b/maha_tts/utils/audio.py @@ -1,15 +1,11 @@ import torch import numpy as np -import librosa.util as librosa_util - from scipy.signal import get_window from scipy.io.wavfile import read -from maha_tts.config import config TACOTRON_MEL_MAX = 2.4 TACOTRON_MEL_MIN = -11.5130 - def denormalize_tacotron_mel(norm_mel): return ((norm_mel+1)/2)*(TACOTRON_MEL_MAX-TACOTRON_MEL_MIN)+TACOTRON_MEL_MIN @@ -18,66 +14,65 @@ def normalize_tacotron_mel(mel): return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1 -def get_mask_from_lengths(lengths, max_len=None): - if not max_len: - max_len = torch.max(lengths).item() - ids = torch.arange(0, max_len, device=lengths.device, dtype=torch.long) - mask = (ids < lengths.unsqueeze(1)).bool() - return mask +def get_mask(lengths, max_len=None): + """ + Generate a mask for sequences based on lengths. + Parameters: + - lengths: Torch tensor, lengths of sequences + - max_len: Optional, maximum length for padding -def get_mask(lengths, max_len=None): - if not max_len: - max_len = torch.max(lengths).item() - lens = torch.arange(max_len,) + Returns: + - Torch tensor, mask for sequences + """ + max_len = max_len or torch.max(lengths).item() + lens = torch.arange(max_len) mask = lens[:max_len].unsqueeze(0) < lengths.unsqueeze(1) return mask - - def dynamic_range_compression(x, C=1, clip_val=1e-5): """ - PARAMS - ------ - C: compression factor + Perform dynamic range compression on input tensor. + + Parameters: + - x: Torch tensor, input tensor + - C: Compression factor + - clip_val: Minimum value to clamp input tensor + + Returns: + - Torch tensor, compressed tensor """ return torch.log(torch.clamp(x, min=clip_val) * C) - def dynamic_range_decompression(x, C=1): """ - PARAMS - ------ - C: compression factor used to compress + Perform dynamic range decompression on input tensor. + + Parameters: + - x: Torch tensor, input tensor + - C: Compression factor used for compression + + Returns: + - Torch tensor, decompressed tensor """ return torch.exp(x) / C - def window_sumsquare(window, n_frames, hop_length=200, win_length=800, n_fft=800, dtype=np.float32, norm=None): """ - # from librosa 0.6 Compute the sum-square envelope of a window function at a given hop length. - This is used to estimate modulation effects induced by windowing - observations in short-time fourier transforms. - Parameters - ---------- - window : string, tuple, number, callable, or list-like - Window specification, as in `get_window` - n_frames : int > 0 - The number of analysis frames - hop_length : int > 0 - The number of samples to advance between frames - win_length : [optional] - The length of the window function. By default, this matches `n_fft`. - n_fft : int > 0 - The length of each analysis frame. - dtype : np.dtype - The data type of the output - Returns - ------- - wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` - The sum-squared envelope of the window function + + Parameters: + - window: String, tuple, number, callable, or list-like; window specification + - n_frames: Int, number of analysis frames + - hop_length: Int, number of samples to advance between frames + - win_length: Int, length of the window function + - n_fft: Int, length of each analysis frame + - dtype: Numpy data type of the output + - norm: Normalization type for the window function + + Returns: + - Numpy array, sum-squared envelope of the window function """ if win_length is None: win_length = n_fft @@ -87,8 +82,8 @@ def window_sumsquare(window, n_frames, hop_length=200, win_length=800, # Compute the squared window at the desired length win_sq = get_window(window, win_length, fftbins=True) - win_sq = librosa_util.normalize(win_sq, norm=norm)**2 - win_sq = librosa_util.pad_center(win_sq, size=n_fft) + win_sq = np.square(librosa.util.normalize(win_sq, norm=norm)) + win_sq = librosa.util.pad_center(win_sq, size=n_fft) # Fill the envelope for i in range(n_frames): @@ -97,13 +92,21 @@ def window_sumsquare(window, n_frames, hop_length=200, win_length=800, return x def load_wav_to_torch(full_path): - sampling_rate, data = read(full_path,) - return torch.FloatTensor(data), sampling_rate + """ + Load WAV file into Torch tensor. + Parameters: + - full_path: String, path to the WAV file + Returns: + - Torch tensor, audio data + - Int, sampling rate + """ + sampling_rate, data = read(full_path) + return torch.FloatTensor(data), sampling_rate if __name__ == "__main__": lens = torch.tensor([2, 3, 7, 5, 4]) - mask = get_mask(lens) + mask = get_mask(lens) print(mask) - print(mask.shape) \ No newline at end of file + print(mask.shape)