DASS: Distilled Audio State Space Models Are Stronger and More Duration-Scalable Learners

Introduction

Figure 1: Overview of DASS.

This repository contains the PyTorch code (cleanup ongoing) for the DASS model proposed in the SLT 2024 paper DASS: Distilled Audio State Space Models Are Stronger and More Duration-Scalable Learners (Saurabhchand Bhati, Yuan Gong, Leonid Karlinsky, Hilde Kuehne, Rogerio Feris, James Glass).

DASS is the first state-space model that outperforms transformer-based audio classifiers such as AST, HTS-AT, and Audio-MAE. We train and evaluate on AudioSet, where DASS achieves a new state-of-the-art mAP of 47.6. DASS is trained with knowledge distillation (KD) from a transformer-based teacher, AST, as shown in Fig. 1(a). KD is an integral part of DASS and helps it outperform its AST teacher (Fig. 1(b)).

| Model | Params | Pretrain | mAP |
|---|---|---|---|
| Transformer-based models | | | |
| AST | 87M | IN SL | 45.9 |
| HTS-AT | 31M | IN SL | 47.1 |
| PaSST | | IN SL | 47.1 |
| Audio-MAE | 86M | SSL | 47.3 |
| Concurrent SSM models | | | |
| AuM | 26M | IN SL | 39.7 |
| Audio Mamba | 40M | IN SL | 44.0 |
| DASS-Small | 30M | IN SL | 47.2 |
| DASS-Medium | 49M | IN SL | 47.6 |

(IN SL: ImageNet supervised pretraining; SSL: self-supervised pretraining.)

The DASS model is defined in src/models/ast_models.py, and the training recipe is in egs/audioset/run.sh. Running run.sh calls src/run.py, which uses src/dataloader.py and src/traintest.py, which in turn load the model from src/models/ast_models.py.

Getting Started

Step 1. Clone or download this repository and set it as the working directory, then create a virtual environment and install the dependencies. The list of packages is in requirements.txt. For reference, the output of pip freeze for the original environment is stored in pip_freeze.txt.

cd DASS/
conda create -n DASS 
conda activate DASS
conda install cuda=12.1 -c nvidia   # only required if CUDA is not already installed
pip install -r requirements.txt
cd kernels/selective_scan && pip install .

Step 2. Run inference with the DASS model.

import torch
import torchaudio

from models.ast_models import DASS

def make_features(wav_name, mel_bins, target_length=1024):
    # Load the waveform and compute Kaldi-style log-Mel filterbank features.
    waveform, sr = torchaudio.load(wav_name)
    waveform = waveform - waveform.mean()
    fbank = torchaudio.compliance.kaldi.fbank(
        waveform, htk_compat=True, sample_frequency=sr, use_energy=False,
        window_type='hanning', num_mel_bins=mel_bins, dither=0.0,
        frame_shift=10)
    # Pad or truncate to target_length frames (1024 frames ~ 10 seconds).
    n_frames = fbank.shape[0]
    p = target_length - n_frames
    if p > 0:
        m = torch.nn.ZeroPad2d((0, 0, 0, p))
        fbank = m(fbank)
    elif p < 0:
        fbank = fbank[0:target_length, :]
    # Normalize with the AudioSet fbank mean and standard deviation.
    fbank = (fbank - (-4.2677393)) / (4.5689974 * 2)
    return fbank

audio_input = make_features('../babble_all.wav', 128).unsqueeze(0)    # path to an audio file

audio_model = DASS(label_dim=527, imagenet_pretrain=False,
                   audioset_pretrain=False, model_size='small')

# Wrap in DataParallel so the checkpoint's 'module.'-prefixed keys load with strict=True.
if not isinstance(audio_model, torch.nn.DataParallel):
    audio_model = torch.nn.DataParallel(audio_model)

audio_model = audio_model.cuda()
audio_model.load_state_dict(torch.load('../pretrained_models/DASS_small.pth'), strict=True)
audio_model.eval()

with torch.no_grad():
    audio_output = audio_model(audio_input.cuda())
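
The model outputs one score per AudioSet class (527 in total). Below is a minimal sketch of turning the output into readable predictions; it assumes the model returns raw logits (as AST does) and that the AudioSet label map class_labels_indices.csv from the AST recipe is available, so the file path and the top-5 choice are illustrative.

import csv

# Illustrative path: the AudioSet label map shipped with the AST recipe.
label_csv = '../egs/audioset/data/class_labels_indices.csv'
with open(label_csv) as f:
    labels = [row['display_name'] for row in csv.DictReader(f)]

# AudioSet is multi-label, so apply a sigmoid (not a softmax) to the logits.
probs = torch.sigmoid(audio_output[0]).cpu()
top_probs, top_idx = torch.topk(probs, k=5)
for p, i in zip(top_probs, top_idx):
    print(f'{labels[i.item()]}: {p.item():.3f}')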

Training

The repository follows the training recipe from AST. Please refer to the AST repository for detailed instructions on data preparation.

cd egs/audioset

model=DASS
model_size=small                                # small or medium
set=balanced                                    # balanced or full
imagenetpretrain=True                           # use imagenet pretrained model
knowledge_distillation=True                     # use AST as the teacher model 
dist_loss_type=kldiv                            # KD loss type: kldiv, cosine, etc.
lr=1e-4                                         # learning rate 
bash ./run.sh ${model} ${model_size} ${set} ${imagenetpretrain} ${knowledge_distillation} ${dist_loss_type} ${lr}
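
The knowledge_distillation and dist_loss_type flags control the distillation objective. For reference, here is a minimal sketch of what a kldiv distillation loss between the AST teacher and the DASS student could look like; the function names, temperature, and loss weighting are illustrative and not necessarily the exact implementation in src/traintest.py.

import torch.nn.functional as F

def kd_kldiv_loss(student_logits, teacher_logits, temperature=1.0):
    # Soften both distributions with the same temperature, then take the KL
    # divergence from the teacher distribution to the student distribution.
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    return F.kl_div(student_log_probs, teacher_probs,
                    reduction='batchmean') * temperature ** 2

def train_loss(student_logits, teacher_logits, targets, lambda_kd=1.0):
    # Supervised multi-label BCE on the AudioSet labels plus the distillation
    # term; lambda_kd balances the two and is an illustrative placeholder.
    bce = F.binary_cross_entropy_with_logits(student_logits, targets)
    return bce + lambda_kd * kd_kldiv_loss(student_logits, teacher_logits)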

NIAH: Needle in the Audio Haystack

State-space models scale linearly with sequence length. Recent work has shown that state-space audio models require less memory and enjoy faster inference than transformer-based models, but these models (AuM, Audio Mamba) do not measure performance on longer input utterances.

To measure performance at various input lengths, we design a Needle In the Audio Haystack (NIAH) task. In NIAH, we insert a 10-second needle at a random location in a generated haystack of varying length. The haystack is either all zeros or fbank features extracted from noise at various SNRs. The DASS model does not require positional embeddings, so it can easily be applied to inputs of variable length. As shown in Figure 1(c), DASS is significantly more duration-robust than AST and retains approximately half of its performance even at a haystack duration of 2.5 hours.

The main script for NIAH is src/run_NIAH.py, which calls src/dataloader_niah.py to generate the haystack and the fbank features for both the needle and the haystack.
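
Below is a minimal sketch of how a NIAH input could be assembled for the zero-filled haystack case, reusing the make_features helper from the inference example above; the function name and shapes are illustrative, not the exact code in src/dataloader_niah.py.

import torch

def build_niah_input(needle_fbank, input_max_len=5000, insert_tstep=0.5):
    # needle_fbank: (n_frames, mel_bins) fbank of the 10-second needle.
    n_frames, mel_bins = needle_fbank.shape
    haystack = torch.zeros(input_max_len, mel_bins)
    # insert_tstep is the relative needle location: 0 = beginning, 0.5 = middle, 1 = end.
    start = int(insert_tstep * (input_max_len - n_frames))
    haystack[start:start + n_frames] = needle_fbank
    return haystack.unsqueeze(0)  # add a batch dimension

needle = make_features('../babble_all.wav', 128)             # 10 s of audio -> 1024 frames
niah_input = build_niah_input(needle, input_max_len=5000)    # ~50 s haystack, needle in the middle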

model_size=small                                # small or medium
ckpt=../pretrained_models/DASS_small.pth        # path to the checkpoint to run NIAH on
input_max_len=5000                              # length of the haystack, 1000 features for 10 seconds
audio_insert_tstep=0.5                          # relative needle location in [0, 1]; 0: beginning, 0.5: middle, 1: end
niah_use_noise=False                            # use noise to construct the haystack; if False, the haystack is made of zeros
niah_noise_snr=10                               # SNR of the noise used to construct the noisy haystack; only used if niah_use_noise=True
niah_noise_type=white                           # white for Gaussian noise, babble for babble noise; only used if niah_use_noise=True
python run_NIAH.py --model_size ${model_size} -ckpt ${ckpt} --input_max_len ${input_max_len} --audio_insert_tstep ${audio_insert_tstep}

Acknowledgment

This project is based on AST (paper, code) and VMamba (paper, code); thanks for their excellent work. Please make sure to check them out.

Citation

@article{bhati2024dass,
  title={DASS: Distilled Audio State Space Models Are Stronger and More Duration-Scalable Learners},
  author={Bhati, Saurabhchand and Gong, Yuan and Karlinsky, Leonid and Kuehne, Hilde and Feris, Rogerio and Glass, James},
  journal={arXiv preprint arXiv:2407.04082},
  year={2024}
}
