add training and batch evaluation scripts
hkchengrex committed Dec 23, 2024
1 parent d294fb8 commit a758040
Showing 56 changed files with 201,176 additions and 80 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -13,6 +13,7 @@ ext_weights/
ext_weights
.checkpoints/
.vscode/
+training/example_output/

# Byte-compiled / optimized / DLL files
__pycache__/
52 changes: 3 additions & 49 deletions README.md
@@ -83,51 +83,7 @@ pip install -e .

The models will be downloaded automatically when you run the demo script. MD5 checksums are provided in `mmaudio/utils/download_utils.py`.
The models are also available at https://huggingface.co/hkchengrex/MMAudio/tree/main

-| Model | Download link | File size |
-| -------- | ------- | ------- |
-| Flow prediction network, small 16kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_small_16k.pth" download="mmaudio_small_16k.pth">mmaudio_small_16k.pth</a> | 601M |
-| Flow prediction network, small 44.1kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_small_44k.pth" download="mmaudio_small_44k.pth">mmaudio_small_44k.pth</a> | 601M |
-| Flow prediction network, medium 44.1kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_medium_44k.pth" download="mmaudio_medium_44k.pth">mmaudio_medium_44k.pth</a> | 2.4G |
-| Flow prediction network, large 44.1kHz | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_large_44k.pth" download="mmaudio_large_44k.pth">mmaudio_large_44k.pth</a> | 3.9G |
-| Flow prediction network, large 44.1kHz, v2 **(recommended)** | <a href="https://huggingface.co/hkchengrex/MMAudio/resolve/main/weights/mmaudio_large_44k_v2.pth" download="mmaudio_large_44k_v2.pth">mmaudio_large_44k_v2.pth</a> | 3.9G |
-| 16kHz VAE | <a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/v1-16.pth">v1-16.pth</a> | 655M |
-| 16kHz BigVGAN vocoder (from Make-An-Audio 2) |<a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/best_netG.pt">best_netG.pt</a> | 429M |
-| 44.1kHz VAE |<a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/v1-44.pth">v1-44.pth</a> | 1.2G |
-| Synchformer visual encoder |<a href="https://github.com/hkchengrex/MMAudio/releases/download/v0.1/synchformer_state_dict.pth">synchformer_state_dict.pth</a> | 907M |

-To run the model, you need four components: a flow prediction network, visual feature extractors (Synchformer and CLIP; CLIP is downloaded automatically), a VAE, and a vocoder. VAEs and vocoders are specific to the sampling rate (16kHz or 44.1kHz), not to model size.
-The 44.1kHz vocoder will be downloaded automatically.

-The expected directory structure (full):

-```bash
-MMAudio
-├── ext_weights
-│   ├── best_netG.pt
-│   ├── synchformer_state_dict.pth
-│   ├── v1-16.pth
-│   └── v1-44.pth
-├── weights
-│   ├── mmaudio_small_16k.pth
-│   ├── mmaudio_small_44k.pth
-│   ├── mmaudio_medium_44k.pth
-│   ├── mmaudio_large_44k.pth
-│   └── mmaudio_large_44k_v2.pth
-└── ...
-```

-The expected directory structure (minimal, for the recommended model only):

-```bash
-MMAudio
-├── ext_weights
-│   ├── synchformer_state_dict.pth
-│   └── v1-44.pth
-├── weights
-│   └── mmaudio_large_44k_v2.pth
-└── ...
-```
+See [MODELS.md](docs/MODELS.md) for more details.

## Demo

@@ -180,13 +136,11 @@ We believe all of these three limitations can be addressed with more high-qualit

## Training

-Work in progress.
+See [TRAINING.md](docs/TRAINING.md).

## Evaluation

-You can access the precomputed results on VGGSound, AudioCaps, and MovieGen here: https://huggingface.co/datasets/hkchengrex/MMAudio-precomputed-results
-
-We have shared our evaluation code here: https://github.com/hkchengrex/av-benchmark
+See [EVAL.md](docs/EVAL.md).

## Training Datasets

110 changes: 110 additions & 0 deletions batch_eval.py
@@ -0,0 +1,110 @@
import logging
import os
from pathlib import Path

import hydra
import torch
import torch.distributed as distributed
import torchaudio
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig
from tqdm import tqdm

from mmaudio.data.data_setup import setup_eval_dataset
from mmaudio.eval_utils import ModelConfig, all_model_cfg, generate
from mmaudio.model.flow_matching import FlowMatching
from mmaudio.model.networks import MMAudio, get_my_mmaudio
from mmaudio.model.utils.features_utils import FeaturesUtils

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

local_rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])
log = logging.getLogger()


@torch.inference_mode()
@hydra.main(version_base='1.3.2', config_path='config', config_name='eval_config.yaml')
def main(cfg: DictConfig):
    device = 'cuda'
    torch.cuda.set_device(local_rank)

    if cfg.model not in all_model_cfg:
        raise ValueError(f'Unknown model variant: {cfg.model}')
    model: ModelConfig = all_model_cfg[cfg.model]
    model.download_if_needed()
    seq_cfg = model.seq_cfg

    run_dir = Path(HydraConfig.get().run.dir)
    if cfg.output_name is None:
        output_dir = run_dir / cfg.dataset
    else:
        output_dir = run_dir / f'{cfg.dataset}-{cfg.output_name}'
    output_dir.mkdir(parents=True, exist_ok=True)

    # load a pretrained model
    seq_cfg.duration = cfg.duration_s
    net: MMAudio = get_my_mmaudio(cfg.model).to(device).eval()
    net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True))
    log.info(f'Loaded weights from {model.model_path}')
    net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len)
    log.info(f'Latent seq len: {seq_cfg.latent_seq_len}')
    log.info(f'Clip seq len: {seq_cfg.clip_seq_len}')
    log.info(f'Sync seq len: {seq_cfg.sync_seq_len}')

    # misc setup
    rng = torch.Generator(device=device)
    rng.manual_seed(cfg.seed)
    fm = FlowMatching(cfg.sampling.min_sigma,
                      inference_mode=cfg.sampling.method,
                      num_steps=cfg.sampling.num_steps)

    feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
                                  synchformer_ckpt=model.synchformer_ckpt,
                                  enable_conditions=True,
                                  mode=model.mode,
                                  bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
                                  need_vae_encoder=False)
    feature_utils = feature_utils.to(device).eval()

    if cfg.compile:
        net.preprocess_conditions = torch.compile(net.preprocess_conditions)
        net.predict_flow = torch.compile(net.predict_flow)
        feature_utils.compile()

    dataset, loader = setup_eval_dataset(cfg.dataset, cfg)

    with torch.amp.autocast(enabled=cfg.amp, dtype=torch.bfloat16, device_type=device):
        for batch in tqdm(loader):
            audios = generate(batch.get('clip_video', None),
                              batch.get('sync_video', None),
                              batch.get('caption', None),
                              feature_utils=feature_utils,
                              net=net,
                              fm=fm,
                              rng=rng,
                              cfg_strength=cfg.cfg_strength,
                              clip_batch_size_multiplier=64,
                              sync_batch_size_multiplier=64)
            audios = audios.float().cpu()
            names = batch['name']
            for audio, name in zip(audios, names):
                torchaudio.save(output_dir / f'{name}.flac', audio, seq_cfg.sampling_rate)


def distributed_setup():
    distributed.init_process_group(backend="nccl")
    local_rank = distributed.get_rank()
    world_size = distributed.get_world_size()
    log.info(f'Initialized: local_rank={local_rank}, world_size={world_size}')
    return local_rank, world_size


if __name__ == '__main__':
    distributed_setup()

    main()

    # clean-up
    distributed.destroy_process_group()
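Since the script reads `LOCAL_RANK` and `WORLD_SIZE` from the environment at import time, it expects a distributed launcher even for single-GPU runs. A minimal launch sketch (the exact flags are an assumption, not part of this commit):

```bash
# Minimal sketch, assuming a single-node launch with torchrun, which sets the
# LOCAL_RANK and WORLD_SIZE variables batch_eval.py reads at import time.
# model/dataset/duration_s are Hydra overrides of config/eval_config.yaml.
torchrun --standalone --nproc_per_node=1 batch_eval.py model=small_16k dataset=audiocaps duration_s=8.0
```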
Empty file added config/__init__.py
62 changes: 62 additions & 0 deletions config/base_config.yaml
@@ -0,0 +1,62 @@
defaults:
  - data: base
  - eval_data: base
  - override hydra/job_logging: custom-simplest
  - _self_

hydra:
  run:
    dir: ./output/${exp_id}
  output_subdir: ${now:%Y-%m-%d_%H-%M-%S}-hydra

enable_email: False

model: small_16k

exp_id: default
debug: False
cudnn_benchmark: True
compile: True
amp: True
weights: null
checkpoint: null
seed: 14159265
num_workers: 10 # per-GPU
pin_memory: False # set to True if your system can handle it, i.e., has enough memory

# NOTE: this DOES NOT affect the model during inference in any way;
# these values are only used by the dataloader to fill in missing data in multi-modal loading.
# To change the sequence lengths for the model, see networks.py.
data_dim:
  text_seq_len: 77
  clip_dim: 1024
  sync_dim: 768
  text_dim: 1024

# EMA configuration
ema:
  enable: True
  sigma_rels: [0.05, 0.1]
  update_every: 1
  checkpoint_every: 5_000
  checkpoint_folder: ./output/${exp_id}/ema_ckpts
  default_output_sigma: 0.05

# sampling
sampling:
  mean: 0.0
  scale: 1.0
  min_sigma: 0.0
  method: euler
  num_steps: 25

# classifier-free guidance
null_condition_probability: 0.1
cfg_strength: 4.5

# checkpoint paths to external modules
vae_16k_ckpt: ./ext_weights/v1-16.pth
vae_44k_ckpt: ./ext_weights/v1-44.pth
bigvgan_vocoder_ckpt: ./ext_weights/best_netG.pt
synchformer_ckpt: ./ext_weights/synchformer_state_dict.pth
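Any of these keys can be overridden from the command line with Hydra's `key=value` syntax, using dots for nested keys. A hypothetical example (the `train.py` entry-point name is an assumption; substitute the actual training script from this commit):

```bash
# Hypothetical: run training with a custom experiment id and EMA disabled.
# exp_id controls the output directory (./output/<exp_id>); ema.enable shows
# Hydra's dotted-path syntax for nested keys.
torchrun --standalone --nproc_per_node=2 train.py exp_id=my_run model=small_16k ema.enable=False
```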
70 changes: 70 additions & 0 deletions config/data/base.yaml
@@ -0,0 +1,70 @@
VGGSound:
  root: ../data/video
  subset_name: sets/vgg3-train.tsv
  fps: 8
  height: 384
  width: 384
  sample_duration_sec: 8.0

VGGSound_test:
  root: ../data/video
  subset_name: sets/vgg3-test.tsv
  fps: 8
  height: 384
  width: 384
  sample_duration_sec: 8.0

VGGSound_val:
  root: ../data/video
  subset_name: sets/vgg3-val.tsv
  fps: 8
  height: 384
  width: 384
  sample_duration_sec: 8.0

ExtractedVGG:
  tsv: ../data/v1-16-memmap/vgg-train.tsv
  memmap_dir: ../data/v1-16-memmap/vgg-train

ExtractedVGG_test:
  tag: test
  gt_cache: ../data/eval-test-v2
  output_subdir: null
  tsv: ../data/v1-16-memmap/vgg-test.tsv
  memmap_dir: ../data/v1-16-memmap/vgg-test

ExtractedVGG_val:
  tag: val
  gt_cache: ../data/eval-val-v2
  output_subdir: val
  tsv: ../data/v1-16-memmap/vgg-val.tsv
  memmap_dir: ../data/v1-16-memmap/vgg-val

AudioCaps:
  tsv: ../data/v1-16-memmap/audiocaps.tsv
  memmap_dir: ../data/v1-16-memmap/audiocaps

AudioSetSL:
  tsv: ../data/v1-16-memmap/audioset_sl.tsv
  memmap_dir: ../data/v1-16-memmap/audioset_sl

BBCSound:
  tsv: ../data/v1-16-memmap/bbcsound.tsv
  memmap_dir: ../data/v1-16-memmap/bbcsound

FreeSound:
  tsv: ../data/v1-16-memmap/freesound.tsv
  memmap_dir: ../data/v1-16-memmap/freesound

Clotho:
  tsv: ../data/v1-16-memmap/clotho.tsv
  memmap_dir: ../data/v1-16-memmap/clotho

Example_video:
  tsv: ./training/example_output/memmap/vgg-example.tsv
  memmap_dir: ./training/example_output/memmap/vgg-example

Example_audio:
  tsv: ./training/example_output/memmap/audio-example.tsv
  memmap_dir: ./training/example_output/memmap/audio-example
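Because this file is loaded under the `data` config group, individual dataset paths can be redirected at launch time instead of editing the YAML. A sketch with placeholder paths:

```bash
# Sketch: redirect the ExtractedVGG memmap dataset via dotted-path overrides.
# /path/to/... are placeholders, and train.py is the assumed training entry point.
torchrun --standalone --nproc_per_node=1 train.py \
    data.ExtractedVGG.tsv=/path/to/vgg-train.tsv \
    data.ExtractedVGG.memmap_dir=/path/to/vgg-train
```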

17 changes: 17 additions & 0 deletions config/eval_config.yaml
@@ -0,0 +1,17 @@
defaults:
  - base_config
  - override hydra/job_logging: custom-simplest
  - _self_

hydra:
  run:
    dir: ./output/${exp_id}
  output_subdir: eval-${now:%Y-%m-%d_%H-%M-%S}-hydra

exp_id: ${model}
dataset: audiocaps
duration_s: 8.0

# for inference, this is the per-GPU batch size
batch_size: 16
output_name: null
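With `exp_id` defaulting to the model name, batch_eval.py writes generated audio to `./output/<model>/<dataset>`, or to `./output/<model>/<dataset>-<output_name>` when `output_name` is set, which keeps runs with different settings apart. A sketch (assuming `vggsound` is an accepted dataset name):

```bash
# Sketch: two evaluation runs that differ only in sampling steps, written to
# ./output/large_44k/vggsound-steps25 and ./output/large_44k/vggsound-steps50.
torchrun --standalone --nproc_per_node=1 batch_eval.py model=large_44k dataset=vggsound output_name=steps25 sampling.num_steps=25
torchrun --standalone --nproc_per_node=1 batch_eval.py model=large_44k dataset=vggsound output_name=steps50 sampling.num_steps=50
```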
20 changes: 20 additions & 0 deletions config/eval_data/base.yaml
@@ -0,0 +1,20 @@
AudioCaps:
  audio_path: ../data/AudioCaps-test-audioldm-ver
  # a csv file with a header row of 'name' and 'caption';
  # 'name' should match the audio file name without its extension
  csv_path: ../data/AudioCaps-test-audioldm-ver/data.csv

AudioCaps_full:
  audio_path: ../data/AudioCaps-test-full-ver
  # a csv file with a header row of 'name' and 'caption';
  # 'name' should match the audio file name without its extension
  csv_path: ../data/AudioCaps-test-full-ver/data.csv

MovieGen:
  video_path: ../data/MovieGen/MovieGenAudioBenchSfx/video_with_audio
  jsonl_path: ../data/MovieGen/MovieGenAudioBenchSfx/metadata

VGGSound:
  video_path: ../data/test-videos
  # from the officially released csv file
  csv_path: ../data/vggsound.csv
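The comments above define the CSV contract; a hypothetical `data.csv` illustrating it (the names and captions are made up):

```bash
# Hypothetical data.csv: a 'name,caption' header row, with each name matching
# an audio file stem, e.g. ../data/AudioCaps-test-audioldm-ver/sample_001.wav.
cat > data.csv << 'EOF'
name,caption
sample_001,A man speaks while water splashes in the background
sample_002,Birds chirp and a dog barks in the distance
EOF
```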
32 changes: 32 additions & 0 deletions config/hydra/job_logging/custom-eval.yaml
@@ -0,0 +1,32 @@
# python logging configuration for tasks
version: 1
formatters:
  simple:
    format: '[%(asctime)s][%(levelname)s][r${oc.env:LOCAL_RANK}] - %(message)s'
    datefmt: '%Y-%m-%d %H:%M:%S'
  colorlog:
    '()': 'colorlog.ColoredFormatter'
    format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
    datefmt: '%Y-%m-%d %H:%M:%S'
    log_colors:
      DEBUG: purple
      INFO: green
      WARNING: yellow
      ERROR: red
      CRITICAL: red
handlers:
  console:
    class: logging.StreamHandler
    formatter: colorlog
    stream: ext://sys.stdout
  file:
    class: logging.FileHandler
    formatter: simple
    # absolute file path
    filename: ${hydra.runtime.output_dir}/eval-${now:%Y-%m-%d_%H-%M-%S}-rank${oc.env:LOCAL_RANK}.log
    mode: w
root:
  level: INFO
  handlers: [console, file]

disable_existing_loggers: false
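eval_config.yaml overrides the logging group to `custom-simplest`, so this config has to be selected explicitly with Hydra's group-override syntax. Both the formatter and the log filename reference `${oc.env:LOCAL_RANK}`, so a launcher that sets that variable is assumed:

```bash
# Sketch: swap in this logging config at launch; torchrun provides the
# LOCAL_RANK variable used in the format string and the log filename.
torchrun --standalone --nproc_per_node=1 batch_eval.py hydra/job_logging=custom-eval
```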