Fast FullSubNet m=infinity #62
Comments
Thanks for your attention. Please refer to the implementation below. If you run this model, you may find its performance is similar to
import torch
import torch.nn as nn
import torchaudio as audio
from torch.nn import functional
from torchinfo import summary
from audio_zen.model.base_model import BaseModel
from audio_zen.model.module.sequence_model import SequenceModel
class Model(BaseModel):
    def __init__(
        self,
        look_ahead,
        shrink_size,
        sequence_model,
        encoder_input_size,
        num_mels,
        noisy_input_num_neighbors,
        encoder_output_num_neighbors,
        norm_type="offline_laplace_norm",
        weight_init=False,
    ):
        """
        Simply FullSubNet.
        Notes:
            In this model, the encoder and bottleneck are corresponding to the fullband model and subband model, respectively.
        """
        super().__init__()
        assert sequence_model in (
            "GRU",
            "LSTM",
        ), f"{self.__class__.__name__} only support GRU and LSTM."
        # Encoder
        self.encoder = nn.Sequential(
            SequenceModel(
                input_size=64,
                hidden_size=384,
                output_size=0,
                num_layers=1,
                bidirectional=False,
                sequence_model=sequence_model,
                output_activate_function=None,
            ),
            SequenceModel(
                input_size=384,
                hidden_size=257,
                output_size=64,
                num_layers=1,
                bidirectional=False,
                sequence_model=sequence_model,
                output_activate_function="ReLU",
            ),
        )
        # Mel filterbank
        self.mel_scale = audio.transforms.MelScale(
            n_mels=num_mels,
            sample_rate=16000,
            f_min=0,
            f_max=8000,
            n_stft=encoder_input_size,
        )
        self.decoder_lstm = nn.Sequential(
            SequenceModel(
                input_size=64,
                hidden_size=512,
                output_size=0,
                num_layers=1,
                bidirectional=False,
                sequence_model=sequence_model,
                output_activate_function=None,
            ),
            SequenceModel(
                input_size=512,
                hidden_size=512,
                output_size=257 * 2,
                num_layers=1,
                bidirectional=False,
                sequence_model=sequence_model,
                output_activate_function=None,
            ),
        )
        self.look_ahead = look_ahead
        self.norm = self.norm_wrapper(norm_type)
        self.num_mels = num_mels
        self.noisy_input_num_neighbors = noisy_input_num_neighbors
        self.enc_output_num_neighbors = encoder_output_num_neighbors
        self.shrink_size = shrink_size
        if weight_init:
            self.apply(self.weight_init)

    # fmt: off
    def forward(self, mix_mag):
        """
        Args:
            mix_mag: noisy magnitude spectrogram
        Returns:
            The real part and imag part of the enhanced spectrogram
        Shapes:
            mix_mag: [B, 1, F, T]
            return: [B, 2, F, T]
        """
        assert mix_mag.dim() == 4
        mix_mag = functional.pad(mix_mag, [0, self.look_ahead])  # Pad the look ahead
        batch_size, num_channels, num_freqs, num_frames = mix_mag.size()
        assert num_channels == 1, f"{self.__class__.__name__} takes a mag feature as inputs."
        # Mel filtering
        mix_mel_mag = self.mel_scale(mix_mag)  # [B, C, F_mel, T]
        _, _, num_freqs_mel, _ = mix_mel_mag.shape
        # Encoder - Fullband Model
        enc_input = self.norm(mix_mel_mag).reshape(batch_size, -1, num_frames)
        enc_output = self.encoder(enc_input).reshape(batch_size, num_channels, -1, num_frames)  # [B, C, F, T]
        dec_input = enc_output.reshape(batch_size, -1, num_frames)
        decoder_lstm_output = self.decoder_lstm(dec_input)  # [B * C, F * 2, T]
        # Decoder - Fullband Linear Model
        dec_output = decoder_lstm_output.reshape(batch_size, 2, num_freqs, num_frames)
        # Output: drop the first `look_ahead` frames so the length matches the unpadded input
        output = dec_output[:, :, :, self.look_ahead:]
        return output

if __name__ == "__main__":
    with torch.no_grad():
        noisy_mag = torch.rand(1, 1, 257, 63)
        model = Model(
            look_ahead=2,
            shrink_size=16,
            sequence_model="LSTM",
            encoder_input_size=257,
            num_mels=64,
            noisy_input_num_neighbors=5,
            encoder_output_num_neighbors=0,
        )
        output = model(noisy_mag)
        print(summary(model, (1, 1, 257, 63), device="cpu"))
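
To make the data flow of the listing above easier to follow, here is a rough shape trace in plain Python. It is only a sketch, not part of the repository: `trace_shapes` is a hypothetical helper, and it assumes that each `SequenceModel` maps `[B, F_in, T]` to `[B, F_out, T]`, as the shape comments in `forward` suggest. The hard-coded sizes (64 and 257) are taken from the listing.

```python
def trace_shapes(batch, num_freqs, num_frames, look_ahead, num_mels):
    """Hypothetical helper: shape after each step of the forward pass above."""
    # The listing hard-codes input_size=64 in the encoder and 257 * 2 in the decoder,
    # so other values of num_mels or num_freqs would not pass through the model.
    if num_mels != 64:
        raise ValueError("encoder input_size is hard-coded to 64 mel bins")
    if num_freqs != 257:
        raise ValueError("decoder output_size is hard-coded to 257 * 2")

    shapes = {}
    padded = num_frames + look_ahead  # functional.pad adds look_ahead frames on the right
    shapes["padded_input"] = (batch, 1, num_freqs, padded)
    shapes["mel"] = (batch, 1, num_mels, padded)          # MelScale: F -> F_mel
    shapes["encoder_input"] = (batch, num_mels, padded)   # channel axis folded away
    shapes["encoder_output"] = (batch, 64, padded)        # second encoder layer, output_size=64
    shapes["decoder_output"] = (batch, 257 * 2, padded)   # second decoder layer, output_size=257 * 2
    shapes["reshaped"] = (batch, 2, num_freqs, padded)    # split into two planes (real, imag)
    # Dropping the first look_ahead frames restores the original length.
    shapes["output"] = (batch, 2, num_freqs, padded - look_ahead)
    return shapes


if __name__ == "__main__":
    for name, shape in trace_shapes(1, 257, 63, look_ahead=2, num_mels=64).items():
        print(f"{name:>15}: {shape}")
```

Because the input is padded at the end and the output is cropped at the start, output frame `t` is taken from the model step at `t + look_ahead`, which is how the model gets to see `look_ahead` future frames.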
Thank you very much for your excellent work. I am replicating the results from your paper (Fast FullSubNet). I have reproduced the results for m = 2 and m = 8, but for m = infinity I have not yet been able to obtain the results described in the paper. If it's all right with you, could you provide the model structure for m = infinity, in which the sub-model is removed? Thank you very much for your help. I am looking forward to your reply!