-
Notifications
You must be signed in to change notification settings - Fork 102
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add training and batch evaluation scripts
- Loading branch information
1 parent
d294fb8
commit a758040
Showing
56 changed files
with
201,176 additions
and
80 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
import logging | ||
import os | ||
from pathlib import Path | ||
|
||
import hydra | ||
import torch | ||
import torch.distributed as distributed | ||
import torchaudio | ||
from hydra.core.hydra_config import HydraConfig | ||
from omegaconf import DictConfig | ||
from tqdm import tqdm | ||
|
||
from mmaudio.data.data_setup import setup_eval_dataset | ||
from mmaudio.eval_utils import ModelConfig, all_model_cfg, generate | ||
from mmaudio.model.flow_matching import FlowMatching | ||
from mmaudio.model.networks import MMAudio, get_my_mmaudio | ||
from mmaudio.model.utils.features_utils import FeaturesUtils | ||
|
||
torch.backends.cuda.matmul.allow_tf32 = True | ||
torch.backends.cudnn.allow_tf32 = True | ||
|
||
local_rank = int(os.environ['LOCAL_RANK']) | ||
world_size = int(os.environ['WORLD_SIZE']) | ||
log = logging.getLogger() | ||
|
||
|
||
@torch.inference_mode() | ||
@hydra.main(version_base='1.3.2', config_path='config', config_name='eval_config.yaml') | ||
def main(cfg: DictConfig): | ||
device = 'cuda' | ||
torch.cuda.set_device(local_rank) | ||
|
||
if cfg.model not in all_model_cfg: | ||
raise ValueError(f'Unknown model variant: {cfg.model}') | ||
model: ModelConfig = all_model_cfg[cfg.model] | ||
model.download_if_needed() | ||
seq_cfg = model.seq_cfg | ||
|
||
run_dir = Path(HydraConfig.get().run.dir) | ||
if cfg.output_name is None: | ||
output_dir = run_dir / cfg.dataset | ||
else: | ||
output_dir = run_dir / f'{cfg.dataset}-{cfg.output_name}' | ||
output_dir.mkdir(parents=True, exist_ok=True) | ||
|
||
# load a pretrained model | ||
seq_cfg.duration = cfg.duration_s | ||
net: MMAudio = get_my_mmaudio(cfg.model).to(device).eval() | ||
net.load_weights(torch.load(model.model_path, map_location=device, weights_only=True)) | ||
log.info(f'Loaded weights from {model.model_path}') | ||
net.update_seq_lengths(seq_cfg.latent_seq_len, seq_cfg.clip_seq_len, seq_cfg.sync_seq_len) | ||
log.info(f'Latent seq len: {seq_cfg.latent_seq_len}') | ||
log.info(f'Clip seq len: {seq_cfg.clip_seq_len}') | ||
log.info(f'Sync seq len: {seq_cfg.sync_seq_len}') | ||
|
||
# misc setup | ||
rng = torch.Generator(device=device) | ||
rng.manual_seed(cfg.seed) | ||
fm = FlowMatching(cfg.sampling.min_sigma, | ||
inference_mode=cfg.sampling.method, | ||
num_steps=cfg.sampling.num_steps) | ||
|
||
feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path, | ||
synchformer_ckpt=model.synchformer_ckpt, | ||
enable_conditions=True, | ||
mode=model.mode, | ||
bigvgan_vocoder_ckpt=model.bigvgan_16k_path, | ||
need_vae_encoder=False) | ||
feature_utils = feature_utils.to(device).eval() | ||
|
||
if cfg.compile: | ||
net.preprocess_conditions = torch.compile(net.preprocess_conditions) | ||
net.predict_flow = torch.compile(net.predict_flow) | ||
feature_utils.compile() | ||
|
||
dataset, loader = setup_eval_dataset(cfg.dataset, cfg) | ||
|
||
with torch.amp.autocast(enabled=cfg.amp, dtype=torch.bfloat16, device_type=device): | ||
for batch in tqdm(loader): | ||
audios = generate(batch.get('clip_video', None), | ||
batch.get('sync_video', None), | ||
batch.get('caption', None), | ||
feature_utils=feature_utils, | ||
net=net, | ||
fm=fm, | ||
rng=rng, | ||
cfg_strength=cfg.cfg_strength, | ||
clip_batch_size_multiplier=64, | ||
sync_batch_size_multiplier=64) | ||
audios = audios.float().cpu() | ||
names = batch['name'] | ||
for audio, name in zip(audios, names): | ||
torchaudio.save(output_dir / f'{name}.flac', audio, seq_cfg.sampling_rate) | ||
|
||
|
||
def distributed_setup(): | ||
distributed.init_process_group(backend="nccl") | ||
local_rank = distributed.get_rank() | ||
world_size = distributed.get_world_size() | ||
log.info(f'Initialized: local_rank={local_rank}, world_size={world_size}') | ||
return local_rank, world_size | ||
|
||
|
||
if __name__ == '__main__': | ||
distributed_setup() | ||
|
||
main() | ||
|
||
# clean-up | ||
distributed.destroy_process_group() |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
defaults: | ||
- data: base | ||
- eval_data: base | ||
- override hydra/job_logging: custom-simplest | ||
- _self_ | ||
|
||
hydra: | ||
run: | ||
dir: ./output/${exp_id} | ||
output_subdir: ${now:%Y-%m-%d_%H-%M-%S}-hydra | ||
|
||
enable_email: False | ||
|
||
model: small_16k | ||
|
||
exp_id: default | ||
debug: False | ||
cudnn_benchmark: True | ||
compile: True | ||
amp: True | ||
weights: null | ||
checkpoint: null | ||
seed: 14159265 | ||
num_workers: 10 # per-GPU | ||
pin_memory: False # set to True if your system can handle it, i.e., have enough memory | ||
|
||
# NOTE: This DOSE NOT affect the model during inference in any way | ||
# they are just for the dataloader to fill in the missing data in multi-modal loading | ||
# to change the sequence length for the model, see networks.py | ||
data_dim: | ||
text_seq_len: 77 | ||
clip_dim: 1024 | ||
sync_dim: 768 | ||
text_dim: 1024 | ||
|
||
# ema configuration | ||
ema: | ||
enable: True | ||
sigma_rels: [0.05, 0.1] | ||
update_every: 1 | ||
checkpoint_every: 5_000 | ||
checkpoint_folder: ./output/${exp_id}/ema_ckpts | ||
default_output_sigma: 0.05 | ||
|
||
|
||
# sampling | ||
sampling: | ||
mean: 0.0 | ||
scale: 1.0 | ||
min_sigma: 0.0 | ||
method: euler | ||
num_steps: 25 | ||
|
||
# classifier-free guidance | ||
null_condition_probability: 0.1 | ||
cfg_strength: 4.5 | ||
|
||
# checkpoint paths to external modules | ||
vae_16k_ckpt: ./ext_weights/v1-16.pth | ||
vae_44k_ckpt: ./ext_weights/v1-44.pth | ||
bigvgan_vocoder_ckpt: ./ext_weights/best_netG.pt | ||
synchformer_ckpt: ./ext_weights/synchformer_state_dict.pth |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
VGGSound: | ||
root: ../data/video | ||
subset_name: sets/vgg3-train.tsv | ||
fps: 8 | ||
height: 384 | ||
width: 384 | ||
sample_duration_sec: 8.0 | ||
|
||
VGGSound_test: | ||
root: ../data/video | ||
subset_name: sets/vgg3-test.tsv | ||
fps: 8 | ||
height: 384 | ||
width: 384 | ||
sample_duration_sec: 8.0 | ||
|
||
VGGSound_val: | ||
root: ../data/video | ||
subset_name: sets/vgg3-val.tsv | ||
fps: 8 | ||
height: 384 | ||
width: 384 | ||
sample_duration_sec: 8.0 | ||
|
||
ExtractedVGG: | ||
tsv: ../data/v1-16-memmap/vgg-train.tsv | ||
memmap_dir: ../data/v1-16-memmap/vgg-train | ||
|
||
ExtractedVGG_test: | ||
tag: test | ||
gt_cache: ../data/eval-test-v2 | ||
output_subdir: null | ||
tsv: ../data/v1-16-memmap/vgg-test.tsv | ||
memmap_dir: ../data/v1-16-memmap/vgg-test | ||
|
||
ExtractedVGG_val: | ||
tag: val | ||
gt_cache: ../data/eval-val-v2 | ||
output_subdir: val | ||
tsv: ../data/v1-16-memmap/vgg-val.tsv | ||
memmap_dir: ../data/v1-16-memmap/vgg-val | ||
|
||
AudioCaps: | ||
tsv: ../data/v1-16-memmap/audiocaps.tsv | ||
memmap_dir: ../data/v1-16-memmap/audiocaps | ||
|
||
AudioSetSL: | ||
tsv: ../data/v1-16-memmap/audioset_sl.tsv | ||
memmap_dir: ../data/v1-16-memmap/audioset_sl | ||
|
||
BBCSound: | ||
tsv: ../data/v1-16-memmap/bbcsound.tsv | ||
memmap_dir: ../data/v1-16-memmap/bbcsound | ||
|
||
FreeSound: | ||
tsv: ../data/v1-16-memmap/freesound.tsv | ||
memmap_dir: ../data/v1-16-memmap/freesound | ||
|
||
Clotho: | ||
tsv: ../data/v1-16-memmap/clotho.tsv | ||
memmap_dir: ../data/v1-16-memmap/clotho | ||
|
||
Example_video: | ||
tsv: ./training/example_output/memmap/vgg-example.tsv | ||
memmap_dir: ./training/example_output/memmap/vgg-example | ||
|
||
Example_audio: | ||
tsv: ./training/example_output/memmap/audio-example.tsv | ||
memmap_dir: ./training/example_output/memmap/audio-example | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
defaults: | ||
- base_config | ||
- override hydra/job_logging: custom-simplest | ||
- _self_ | ||
|
||
hydra: | ||
run: | ||
dir: ./output/${exp_id} | ||
output_subdir: eval-${now:%Y-%m-%d_%H-%M-%S}-hydra | ||
|
||
exp_id: ${model} | ||
dataset: audiocaps | ||
duration_s: 8.0 | ||
|
||
# for inference, this is the per-GPU batch size | ||
batch_size: 16 | ||
output_name: null |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
AudioCaps: | ||
audio_path: ../data/AudioCaps-test-audioldm-ver | ||
# a csv file, with a header row of 'name' and 'caption' | ||
# name should match the audio file name without extension | ||
csv_path: ../data/AudioCaps-test-audioldm-ver/data.csv | ||
|
||
AudioCaps_full: | ||
audio_path: ../data/AudioCaps-test-full-ver | ||
# a csv file, with a header row of 'name' and 'caption' | ||
# name should match the audio file name without extension | ||
csv_path: ../data/AudioCaps-test-full-ver/data.csv | ||
|
||
MovieGen: | ||
video_path: ../data/MovieGen/MovieGenAudioBenchSfx/video_with_audio | ||
jsonl_path: ../data/MovieGen/MovieGenAudioBenchSfx/metadata | ||
|
||
VGGSound: | ||
video_path: ../data/test-videos | ||
# from the official released csv file | ||
csv_path: ../data/vggsound.csv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# python logging configuration for tasks | ||
version: 1 | ||
formatters: | ||
simple: | ||
format: '[%(asctime)s][%(levelname)s][r${oc.env:LOCAL_RANK}] - %(message)s' | ||
datefmt: '%Y-%m-%d %H:%M:%S' | ||
colorlog: | ||
'()': 'colorlog.ColoredFormatter' | ||
format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s' | ||
datefmt: '%Y-%m-%d %H:%M:%S' | ||
log_colors: | ||
DEBUG: purple | ||
INFO: green | ||
WARNING: yellow | ||
ERROR: red | ||
CRITICAL: red | ||
handlers: | ||
console: | ||
class: logging.StreamHandler | ||
formatter: colorlog | ||
stream: ext://sys.stdout | ||
file: | ||
class: logging.FileHandler | ||
formatter: simple | ||
# absolute file path | ||
filename: ${hydra.runtime.output_dir}/eval-${now:%Y-%m-%d_%H-%M-%S}-rank${oc.env:LOCAL_RANK}.log | ||
mode: w | ||
root: | ||
level: INFO | ||
handlers: [console, file] | ||
|
||
disable_existing_loggers: false |
Oops, something went wrong.