Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fast FullSubNet m=infinity #62

Open
SEMLLYCAT opened this issue Mar 6, 2023 · 1 comment
Open

Fast FullSubNet m=infinity #62

SEMLLYCAT opened this issue Mar 6, 2023 · 1 comment

Comments

@SEMLLYCAT
Copy link

Thank you very much for your excellent work. I'm replicating the results from your paper (Fast FullSubNet). I have reproduced the results for m = 2 and m = 8, but when reproducing m = infinity, I could not obtain the result described in your paper. If it's all right with you, I hope you could provide me the model structure for m = infinity, where the sub-model is removed. Thank you very much for your help, and I am looking forward to your reply!

@haoxiangsnr
Copy link
Member

haoxiangsnr commented Mar 7, 2023

Thanks for your attention.

Please refer to the implementation below. If you run this model, you may find that its performance is similar to that of $m=8$.

import torch
import torch.nn as nn
import torchaudio as audio
from torch.nn import functional
from torchinfo import summary

from audio_zen.model.base_model import BaseModel
from audio_zen.model.module.sequence_model import SequenceModel


class Model(BaseModel):
    def __init__(
        self,
        look_ahead,
        shrink_size,
        sequence_model,
        encoder_input_size,
        num_mels,
        noisy_input_num_neighbors,
        encoder_output_num_neighbors,
        norm_type="offline_laplace_norm",
        weight_init=False,
    ):
        """
        Simplified FullSubNet (the m = infinity case, i.e. the subband model is removed).

        Notes:
            In this model, the encoder and bottleneck correspond to the fullband model
            and subband model of the original FullSubNet, respectively.

        Args:
            look_ahead: Number of future frames to pad onto the input (frames of
                algorithmic latency).
            shrink_size: Kept for interface compatibility with the full model;
                stored but unused here.
            sequence_model: Recurrent cell type, one of "GRU" or "LSTM".
            encoder_input_size: Number of STFT frequency bins (n_stft for the mel
                filterbank, e.g. 257 for a 512-point FFT).
            num_mels: Number of mel bins produced by the mel filterbank.
            noisy_input_num_neighbors: Kept for interface compatibility; stored
                but unused here.
            encoder_output_num_neighbors: Kept for interface compatibility; stored
                but unused here.
            norm_type: Name of the normalization resolved via ``self.norm_wrapper``.
            weight_init: If True, apply ``self.weight_init`` to all submodules.
        """
        super().__init__()
        assert sequence_model in (
            "GRU",
            "LSTM",
        ), f"{self.__class__.__name__} only support GRU and LSTM."

        # Encoder - plays the role of the fullband model: two stacked
        # unidirectional recurrent layers operating on the mel spectrogram.
        self.encoder = nn.Sequential(
            SequenceModel(
                input_size=64,
                hidden_size=384,
                output_size=0,  # 0 => no output projection; raw hidden states are passed on
                num_layers=1,
                bidirectional=False,
                sequence_model=sequence_model,
                output_activate_function=None,
            ),
            SequenceModel(
                input_size=384,
                hidden_size=257,
                output_size=64,
                num_layers=1,
                bidirectional=False,
                sequence_model=sequence_model,
                output_activate_function="ReLU",
            ),
        )

        # Mel filterbank: compresses the linear-frequency magnitude spectrogram
        # (encoder_input_size bins) down to num_mels mel bins.
        self.mel_scale = audio.transforms.MelScale(
            n_mels=num_mels,
            sample_rate=16000,
            f_min=0,
            f_max=8000,
            n_stft=encoder_input_size,
        )

        # Decoder: maps the 64-dim encoder features back to a complex ratio mask,
        # i.e. 257 * 2 outputs per frame (real and imaginary parts).
        self.decoder_lstm = nn.Sequential(
            SequenceModel(
                input_size=64,
                hidden_size=512,
                output_size=0,  # 0 => no output projection; raw hidden states are passed on
                num_layers=1,
                bidirectional=False,
                sequence_model=sequence_model,
                output_activate_function=None,
            ),
            SequenceModel(
                input_size=512,
                hidden_size=512,
                output_size=257 * 2,
                num_layers=1,
                bidirectional=False,
                sequence_model=sequence_model,
                output_activate_function=None,
            ),
        )

        self.look_ahead = look_ahead
        self.norm = self.norm_wrapper(norm_type)
        self.num_mels = num_mels
        # The following attributes are kept for configuration compatibility with
        # the full FullSubNet model; they are not used by this simplified model.
        self.noisy_input_num_neighbors = noisy_input_num_neighbors
        self.enc_output_num_neighbors = encoder_output_num_neighbors
        self.shrink_size = shrink_size

        if weight_init:
            self.apply(self.weight_init)

    # fmt: off
    def forward(self, mix_mag):
        """
        Enhance a noisy magnitude spectrogram.

        Args:
            mix_mag: noisy magnitude spectrogram

        Returns:
            The real part and imag part of the enhanced spectrogram

        Shapes:
            mix_mag: [B, 1, F, T]
            return: [B, 2, F, T]
        """
        assert mix_mag.dim() == 4
        mix_mag = functional.pad(mix_mag, [0, self.look_ahead])  # Pad the look ahead
        batch_size, num_channels, num_freqs, num_frames = mix_mag.size()
        assert num_channels == 1, f"{self.__class__.__name__} takes a mag feature as inputs."

        # Mel filtering: [B, C, F, T] -> [B, C, F_mel, T]
        mix_mel_mag = self.mel_scale(mix_mag)

        # Encoder - Fullband Model (frequency axis flattened into the feature dim)
        enc_input = self.norm(mix_mel_mag).reshape(batch_size, -1, num_frames)
        enc_output = self.encoder(enc_input).reshape(batch_size, num_channels, -1, num_frames)  # [B, C, F, T]

        dec_input = enc_output.reshape(batch_size, -1, num_frames)
        decoder_lstm_output = self.decoder_lstm(dec_input)  # [B * C, F * 2, T]

        # Decoder output interpreted as real/imag channels: [B, 2, F, T]
        dec_output = decoder_lstm_output.reshape(batch_size, 2, num_freqs, num_frames)

        # Drop the first look_ahead frames so output frames align with the input
        output = dec_output[:, :, :, self.look_ahead:]

        return output


if __name__ == "__main__":
    # Quick smoke test: build the m = infinity model and run a dummy
    # magnitude spectrogram ([B, C, F, T]) through it, then print a summary.
    with torch.no_grad():
        model = Model(
            look_ahead=2,
            shrink_size=16,
            sequence_model="LSTM",
            encoder_input_size=257,
            num_mels=64,
            noisy_input_num_neighbors=5,
            encoder_output_num_neighbors=0,
        )
        noisy_mag = torch.rand(1, 1, 257, 63)
        output = model(noisy_mag)
        print(summary(model, (1, 1, 257, 63), device="cpu"))

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants