diff --git a/.gitignore b/.gitignore index 264d9d6c..a7e7d2cc 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,8 @@ wandb/ log/ *.log outputs/ -data/ \ No newline at end of file +data/ + +.gitignore +examples/vsr_LRS3/scripts/decode_avhubert_vo_vicuna_7b_noself.sh +examples/asr_librispeech/scripts/decode_hubert_xtralarge_linear_vicuna_7b_copy.sh \ No newline at end of file diff --git a/examples/asr_librispeech/asr_config.py b/examples/asr_librispeech/asr_config.py index fea75dcc..280280dc 100644 --- a/examples/asr_librispeech/asr_config.py +++ b/examples/asr_librispeech/asr_config.py @@ -15,7 +15,10 @@ class ModelConfig: encoder_projector_ds_rate: int = 5 modal: str = "audio" normalize: Optional[bool] = field(default=False, metadata={ - "help": "whether inpit is normalized, used for models such as wavlm" + "help": "whether input is normalized, used for models such as wavlm" + }) + encoder_type: str = field(default="finetune", metadata={ + "help": "whether model is only pretrained or finetuned, used for models such as hubert" }) @dataclass @@ -97,7 +100,7 @@ class DataConfig: "help": "80 for whisper large v1 and v2, 128 for v3" }) normalize: Optional[bool] = field(default=False, metadata={ - "help": "whether inpit is normalized, used for models such as wavlm" + "help": "whether input is normalized, used for models such as wavlm" }) @dataclass diff --git a/examples/asr_librispeech/scripts/decode_hubert_xtralarge_linear_vicuna_7b.sh b/examples/asr_librispeech/scripts/decode_hubert_xtralarge_linear_vicuna_7b.sh new file mode 100755 index 00000000..c196d66c --- /dev/null +++ b/examples/asr_librispeech/scripts/decode_hubert_xtralarge_linear_vicuna_7b.sh @@ -0,0 +1,61 @@ +#!/bin/bash +#export PYTHONPATH=/root/whisper:$PYTHONPATH +export PYTHONPATH=/root/fairseq:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=0 +export TOKENIZERS_PARALLELISM=false +# export CUDA_LAUNCH_BLOCKING=1 + +run_dir=/root/SLAM-LLM +cd $run_dir +code_dir=examples/asr_librispeech + +speech_encoder_path=/nfs/yangguanrou.ygr/ckpts/hubert_ckpt/hubert_xtralarge_ll60k_finetune_ls960.pt +llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5 + +output_dir=/nfs/yangguanrou.ygr/experiments_hubert/vicuna-7b-v1.5-hubert_xtralarge_ll60k_finetune_ls960 +ckpt_path=$output_dir/asr_epoch_1_step_1000 +split=librispeech_test_clean +val_data_path=/nfs/maziyang.mzy/data/librispeech/${split}.jsonl +decode_log=$ckpt_path/decode_${split}_beam4 + +# -m debugpy --listen 5678 --wait-for-client +python $code_dir/inference_asr_batch.py \ + --config-path "conf" \ + --config-name "prompt.yaml" \ + hydra.run.dir=$ckpt_path \ + ++model_config.llm_name="vicuna-7b-v1.5" \ + ++model_config.llm_path=$llm_path \ + ++model_config.llm_dim=4096 \ + ++model_config.encoder_name=hubert \ + ++model_config.normalize=true \ + ++dataset_config.normalize=true \ + ++model_config.encoder_projector_ds_rate=5 \ + ++model_config.encoder_path=$speech_encoder_path \ + ++model_config.encoder_dim=1280 \ + ++model_config.encoder_type=finetune \ + ++model_config.encoder_projector=linear \ + ++dataset_config.dataset=speech_dataset \ + ++dataset_config.val_data_path=$val_data_path \ + ++dataset_config.input_type=raw \ + ++dataset_config.inference_mode=true \ + ++dataset_config.prompt="Transcribe speech to text. 
" \ + ++train_config.model_name=asr \ + ++train_config.freeze_encoder=true \ + ++train_config.freeze_llm=true \ + ++train_config.batching_strategy=custom \ + ++train_config.num_epochs=1 \ + ++train_config.val_batch_size=1 \ + ++train_config.num_workers_dataloader=0 \ + ++train_config.output_dir=$output_dir \ + ++decode_log=$decode_log \ + ++ckpt_path=$ckpt_path/model.pt \ + # ++peft_ckpt=$ckpt_path \ + # ++train_config.use_peft=true \ + # ++train_config.peft_config.r=32 \ + # ++dataset_config.normalize=true \ + # ++model_config.encoder_projector=q-former \ + # ++dataset_config.fix_length_audio=64 \ + +python src/slam_llm/utils/whisper_tn.py ${decode_log}_gt ${decode_log}_gt.proc +python src/slam_llm/utils/whisper_tn.py ${decode_log}_pred ${decode_log}_pred.proc +python src/slam_llm/utils/compute_wer.py ${decode_log}_gt.proc ${decode_log}_pred.proc ${decode_log}.proc.wer diff --git a/examples/asr_librispeech/scripts/finetune_hubert_xtralarge_linear_vicuna_7b.sh b/examples/asr_librispeech/scripts/finetune_hubert_xtralarge_linear_vicuna_7b.sh new file mode 100755 index 00000000..73747629 --- /dev/null +++ b/examples/asr_librispeech/scripts/finetune_hubert_xtralarge_linear_vicuna_7b.sh @@ -0,0 +1,77 @@ +#!/bin/bash +# export PYTHONPATH=/root/whisper:$PYTHONPATH +export PYTHONPATH=/root/fairseq:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=0,1,2,3 +export TOKENIZERS_PARALLELISM=false +# export CUDA_LAUNCH_BLOCKING=1 +export OMP_NUM_THREADS=1 + +# debug setting for multiple gpus +# export NCCL_DEBUG=INFO +# export NCCL_DEBUG_SUBSYS=ALL +# export TORCH_DISTRIBUTED_DEBUG=INFO + +run_dir=/root/SLAM-LLM +cd $run_dir +code_dir=examples/asr_librispeech + +speech_encoder_path=/nfs/yangguanrou.ygr/ckpts/hubert_ckpt/hubert_xtralarge_ll60k_finetune_ls960.pt +llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5 +train_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_train_960h.jsonl +val_data_path=/nfs/maziyang.mzy/data/librispeech/librispeech_dev_other.jsonl + +output_dir=/root/tmp/vicuna-7b-v1.5-librispeech-linear-steplrwarmupkeep1e-4-hubert-xtralarge-$(date +"%Y%m%d") + +hydra_args=" +hydra.run.dir=$output_dir \ +++model_config.llm_name=vicuna-7b-v1.5 \ +++model_config.llm_path=$llm_path \ +++model_config.llm_dim=4096 \ +++model_config.encoder_name=hubert \ +++model_config.normalize=true \ +++dataset_config.normalize=true \ +++model_config.encoder_projector_ds_rate=5 \ +++model_config.encoder_path=$speech_encoder_path \ +++model_config.encoder_dim=1280 \ +++model_config.encoder_type=finetune \ +++model_config.encoder_projector=linear \ +++dataset_config.dataset=speech_dataset \ +++dataset_config.train_data_path=$train_data_path \ +++dataset_config.val_data_path=$val_data_path \ +++dataset_config.input_type=raw \ +++dataset_config.prompt=\"Transcribe speech to text. 
\" \ +++train_config.model_name=asr \ +++train_config.num_epochs=3 \ +++train_config.freeze_encoder=true \ +++train_config.freeze_llm=true \ +++train_config.batching_strategy=custom \ +++train_config.warmup_steps=1000 \ +++train_config.total_steps=100000 \ +++train_config.lr=1e-4 \ +++train_config.validation_interval=2000 \ +++train_config.batch_size_training=6 \ +++train_config.val_batch_size=6 \ +++train_config.num_workers_dataloader=0 \ +++train_config.output_dir=$output_dir \ +++metric=acc \ +" + +# -m debugpy --listen 5678 --wait-for-client +if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then + python -m debugpy --listen 5678 --wait-for-client $code_dir/finetune_asr.py \ + --config-path "conf" \ + --config-name "prompt.yaml" \ + $hydra_args +else + torchrun \ + --nnodes 1 \ + --nproc_per_node 4 \ + --master_port=29503 \ + $code_dir/finetune_asr.py \ + --config-path "conf" \ + --config-name "prompt.yaml" \ + ++train_config.enable_fsdp=false \ + ++train_config.enable_ddp=true \ + ++train_config.use_fp16=true \ + $hydra_args +fi diff --git a/examples/vsr_LRS3/README.md b/examples/vsr_LRS3/README.md new file mode 100644 index 00000000..0fbc66b9 --- /dev/null +++ b/examples/vsr_LRS3/README.md @@ -0,0 +1,30 @@ +# VSR_LRS3 + +## Performance and checkpoints +We only train the linear projector in this recipe. +Encoder | Projector | LLM | test +|---|---|---|---| +[AV-HuBERT Large + Self-Training](https://dl.fbaipublicfiles.com/avhubert/model/lrs3_vox/vsr/self_large_vox_433h.pt) | [Linear]()(~18.88M) | [vicuna-7b-v1.5](https://huggingface.co/lmsys/vicuna-7b-v1.5) | 29.47 + + +## Data preparation +Follow the steps in [preparation](https://github.com/facebookresearch/av_hubert/tree/main/avhubert/preparation) of av_hubert to pre-process LRS3 dataset + +## Environment +Use the specific fairseq version of [av_hubert](https://github.com/facebookresearch/av_hubert), which is compatible with hydra-core versions below 1.0.7 and omegaconf versions below 2.0.6. + + +## Decode with checkpoints +``` +bash decode_avhubert_vo_vicuna_7b.sh +``` +Modify the path including `speech_encoder_path`, `llm_path`, `output_dir`, `ckpt_path` and `decode_log` in the script when you run the shell script. + +## Train a new model + +### Use the visual part of AV-HuBERT Large as the encoder +``` +bash finetune_avhubert_vo_vicuna_7b.sh +``` + + diff --git a/examples/vsr_LRS3/conf/ds_config.json b/examples/vsr_LRS3/conf/ds_config.json new file mode 100644 index 00000000..7ea70e4a --- /dev/null +++ b/examples/vsr_LRS3/conf/ds_config.json @@ -0,0 +1,19 @@ +{ + "train_micro_batch_size_per_gpu": 4, + "gradient_accumulation_steps": 1, + "optimizer": { + "type": "Adam", + "params": { + "lr": 1e-4 + } + }, + "fp16": { + "enabled": true + }, + "zero_optimization": { + "stage": 3, + "offload_optimizer": { + "device": "cpu" + } + } +} \ No newline at end of file diff --git a/examples/vsr_LRS3/conf/prompt.yaml b/examples/vsr_LRS3/conf/prompt.yaml new file mode 100644 index 00000000..c6576d83 --- /dev/null +++ b/examples/vsr_LRS3/conf/prompt.yaml @@ -0,0 +1,3 @@ +dataset_config: + # we put prompt here, because the hydra override in shell script only support a small subset of chars + prompt: "Transcribe the silent speech in this video to text by lip-reading the speaker's clear and visible lip movements." 
diff --git a/examples/vsr_LRS3/finetune_vsr.py b/examples/vsr_LRS3/finetune_vsr.py new file mode 100644 index 00000000..39c18c36 --- /dev/null +++ b/examples/vsr_LRS3/finetune_vsr.py @@ -0,0 +1,45 @@ +from slam_llm.pipeline.finetune import main as train + +import hydra +import logging +from dataclasses import dataclass, field +from omegaconf import DictConfig, ListConfig, OmegaConf +from vsr_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig + +@dataclass +class RunConfig: + dataset_config: DataConfig = field(default_factory=DataConfig) + model_config: ModelConfig = field(default_factory=ModelConfig) + train_config: TrainConfig = field(default_factory=TrainConfig) + log_config: LogConfig = field(default_factory=LogConfig) + fsdp_config: FSDPConfig = field(default_factory=FSDPConfig) + debug: bool = field(default=False, metadata={"help": "Use pdb when true"}) + metric: str = field(default="acc", metadata={"help": "The metric for evaluation"}) + +@hydra.main(config_name=None) +def main_hydra(cfg: DictConfig): + run_config = RunConfig() + cfg = OmegaConf.merge(run_config, cfg) + def to_plain_list(cfg_item): + if isinstance(cfg_item, ListConfig): + return OmegaConf.to_container(cfg_item, resolve=True) + elif isinstance(cfg_item, DictConfig): + return {k: to_plain_list(v) for k, v in cfg_item.items()} + else: + return cfg_item + + # kwargs = to_plain_list(cfg) + kwargs = cfg + log_level = getattr(logging, kwargs.get("log_level", "INFO").upper()) + + logging.basicConfig(level=log_level) + + if kwargs.get("debug", False): + import pdb; + pdb.set_trace() + + train(kwargs) + + +if __name__ == "__main__": + main_hydra() \ No newline at end of file diff --git a/examples/vsr_LRS3/inference_vsr_batch.py b/examples/vsr_LRS3/inference_vsr_batch.py new file mode 100644 index 00000000..a84d186d --- /dev/null +++ b/examples/vsr_LRS3/inference_vsr_batch.py @@ -0,0 +1,53 @@ +from slam_llm.pipeline.inference_batch import main as inference + +import hydra +import logging +from dataclasses import dataclass, field +from omegaconf import DictConfig, ListConfig, OmegaConf +from typing import Optional +from vsr_config import ModelConfig, TrainConfig, DataConfig, LogConfig, FSDPConfig + + +@dataclass +class RunConfig: + dataset_config: DataConfig = field(default_factory=DataConfig) + model_config: ModelConfig = field(default_factory=ModelConfig) + train_config: TrainConfig = field(default_factory=TrainConfig) + log_config: LogConfig = field(default_factory=LogConfig) + fsdp_config: FSDPConfig = field(default_factory=FSDPConfig) + debug: bool = field(default=False, metadata={"help": "Use pdb when true"}) + metric: str = field(default="acc", metadata={"help": "The metric for evaluation"}) + decode_log: str = field( + default="output/decode_log", + metadata={"help": "The prefix for the decode output"}, + ) + ckpt_path: str = field( + default="output/model.pt", metadata={"help": "The path to projector checkpoint"} + ) + peft_ckpt: Optional[str] = field( + default=None, + metadata={ + "help": "The path to peft checkpoint, should be a directory including adapter_config.json" + }, + ) + + +@hydra.main(config_name=None) +def main_hydra(cfg: DictConfig): + run_config = RunConfig() + cfg = OmegaConf.merge(run_config, cfg) + # kwargs = to_plain_list(cfg) + log_level = getattr(logging, cfg.get("log_level", "INFO").upper()) + + logging.basicConfig(level=log_level) + + if cfg.get("debug", False): + import pdb + + pdb.set_trace() + + inference(cfg) + + +if __name__ == "__main__": + main_hydra() diff --git 
a/examples/vsr_LRS3/model/slam_model_vsr.py b/examples/vsr_LRS3/model/slam_model_vsr.py new file mode 100644 index 00000000..0910d2ed --- /dev/null +++ b/examples/vsr_LRS3/model/slam_model_vsr.py @@ -0,0 +1,155 @@ +import torch +import os +import logging +from slam_llm.models.slam_model import ( + slam_model, + setup_tokenizer, + setup_encoder, + setup_encoder_projector, + setup_llm, +) +from slam_llm.utils.train_utils import print_model_size + +logger = logging.getLogger(__name__) + +def model_factory(train_config, model_config, **kwargs): + # return necessary components for training + tokenizer = setup_tokenizer(train_config, model_config, **kwargs) + + encoder = setup_encoder(train_config, model_config, **kwargs) + + # llm + llm = setup_llm(train_config, model_config, **kwargs) + + # projector + encoder_projector = setup_encoder_projector( + train_config, model_config, **kwargs + ) + model = slam_model_asr( + encoder, + llm, + encoder_projector, + tokenizer, + train_config, + model_config, + **kwargs, + ) + + ckpt_path = kwargs.get( + "ckpt_path", None + ) # FIX(MZY): load model ckpt(mainly projector, related to model_checkpointing/checkpoint_handler.py: save_model_checkpoint_peft) + if ckpt_path is not None: + logger.info("loading other parts from: {}".format(ckpt_path)) + ckpt_dict = torch.load(ckpt_path, map_location="cpu") + model.load_state_dict(ckpt_dict, strict=False) + + print_model_size( + model, + train_config, + ( + int(os.environ["RANK"]) + if train_config.enable_fsdp or train_config.enable_ddp + else 0 + ), + ) + return model, tokenizer + + +class slam_model_asr(slam_model): + def __init__( + self, + encoder, + llm, + encoder_projector, + tokenizer, + train_config, + model_config, + **kwargs, + ): + super().__init__( + encoder, + llm, + encoder_projector, + tokenizer, + train_config, + model_config, + **kwargs, + ) + + + @torch.no_grad() + def inference( + self, + wav_path=None, + prompt=None, + generation_config=None, + logits_processor=None, + stopping_criteria=None, + prefix_allowed_tokens_fn=None, + synced_gpus=None, + assistant_model=None, + streamer=None, + negative_prompt_ids=None, + negative_prompt_attention_mask=None, + **kwargs, + ): + # inference for asr model + + device = kwargs.get("device", "cuda") + if os.path.exists(wav_path): # Audio-Text QA + import whisper + + audio_raw = whisper.load_audio(wav_path) + audio_raw = whisper.pad_or_trim(audio_raw) + + mel_size = getattr( + self.dataset_config, "mel_size", 80 + ) # 80 for large v1 and v2, 128 for large v3 + audio_mel = ( + whisper.log_mel_spectrogram(audio_raw, n_mels=mel_size) + .permute(1, 0)[None, :, :] + .to(device) + ) + + encoder_outs = self.encoder.extract_variable_length_features( + audio_mel.permute(0, 2, 1) + ) + + if self.model_config.encoder_projector == "q-former": + audio_mel_post_mask = torch.ones( + encoder_outs.size()[:-1], dtype=torch.long + ).to(encoder_outs.device) + encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask) + if self.model_config.encoder_projector == "linear": + encoder_outs = self.encoder_projector(encoder_outs) + else: # Text QA + encoder_outs = torch.empty( + 1, 0, self.llm.model.embed_tokens.embedding_dim + ).to(device) + + prompt = "USER: {}\n ASSISTANT:".format(prompt) + prompt_ids = self.tokenizer.encode(prompt) + prompt_length = len(prompt_ids) + prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(device) + + if hasattr(self.llm.model, "embed_tokens"): + inputs_embeds = self.llm.model.embed_tokens(prompt_ids) + elif 
hasattr(self.llm.model.model, "embed_tokens"): + inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids) + else: + inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids) + + inputs_embeds = torch.cat( + (encoder_outs, inputs_embeds[None, :, :]), dim=1 + ) # [audio,prompt] + + attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to( + inputs_embeds.device + ) + + # generate + model_outputs = self.generate( + inputs_embeds=inputs_embeds, attention_mask=attention_mask, **kwargs + ) + + return model_outputs diff --git a/examples/vsr_LRS3/scripts/decode_avhubert_vo_vicuna_7b.sh b/examples/vsr_LRS3/scripts/decode_avhubert_vo_vicuna_7b.sh new file mode 100755 index 00000000..42eb3b78 --- /dev/null +++ b/examples/vsr_LRS3/scripts/decode_avhubert_vo_vicuna_7b.sh @@ -0,0 +1,55 @@ +#!/bin/bash +#export PYTHONPATH=/root/whisper:$PYTHONPATH +# export PYTHONPATH=/root/fairseq:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=0 +export TOKENIZERS_PARALLELISM=false +# export CUDA_LAUNCH_BLOCKING=1 + +run_dir=/root/SLAM-LLM +cd $run_dir +code_dir=examples/vsr_LRS3 + +speech_encoder_path=/nfs/yangguanrou.ygr/codes/av_hubert/self_large_vox_433h_new.pt +llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5 + +output_dir=/nfs/yangguanrou.ygr/experiments_avhubert/vicuna-7b-v1.5-large_vox_433h-tri-dataset-tiaocan_again +ckpt_path=$output_dir/asr/850 + +decode_log=$ckpt_path/decode_${split}_beam4_again + +# -m debugpy --listen 5678 --wait-for-client +python $code_dir/inference_vsr_batch.py \ + --config-path "conf" \ + --config-name "prompt.yaml" \ + hydra.run.dir=$ckpt_path \ + +model_config.llm_name="vicuna-7b-v1.5" \ + +model_config.llm_path=$llm_path \ + +model_config.llm_dim=4096 \ + +model_config.encoder_name=av_hubert \ + +model_config.encoder_projector_ds_rate=5 \ + +model_config.encoder_path=$speech_encoder_path \ + +model_config.encoder_dim=1024 \ + +model_config.encoder_projector=cov1d-linear \ + +dataset_config.dataset=avhubert_dataset \ + +dataset_config.inference_mode=true \ + +dataset_config.test_split=test \ + +train_config.model_name=vsr \ + +train_config.freeze_encoder=true \ + +train_config.freeze_llm=true \ + +train_config.batching_strategy=custom \ + +train_config.num_epochs=1 \ + +train_config.val_batch_size=8 \ + +train_config.num_workers_dataloader=0 \ + +train_config.output_dir=$output_dir \ + +decode_log=$decode_log \ + +ckpt_path=$ckpt_path/model.pt \ + # +peft_ckpt=$ckpt_path \ + # +train_config.use_peft=true \ + # +train_config.peft_config.r=32 \ + # +dataset_config.normalize=true \ + # +model_config.encoder_projector=q-former \ + # +dataset_config.fix_length_audio=64 \ + +python src/slam_llm/utils/whisper_tn.py ${decode_log}_gt ${decode_log}_gt.proc +python src/slam_llm/utils/whisper_tn.py ${decode_log}_pred ${decode_log}_pred.proc +python src/slam_llm/utils/compute_wer.py ${decode_log}_gt.proc ${decode_log}_pred.proc ${decode_log}.proc.wer diff --git a/examples/vsr_LRS3/scripts/decode_avhubert_vo_vicuna_7b_noself.sh b/examples/vsr_LRS3/scripts/decode_avhubert_vo_vicuna_7b_noself.sh new file mode 100755 index 00000000..07f13827 --- /dev/null +++ b/examples/vsr_LRS3/scripts/decode_avhubert_vo_vicuna_7b_noself.sh @@ -0,0 +1,55 @@ +#!/bin/bash +#export PYTHONPATH=/root/whisper:$PYTHONPATH +# export PYTHONPATH=/root/fairseq:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=0 +export TOKENIZERS_PARALLELISM=false +# export CUDA_LAUNCH_BLOCKING=1 + +run_dir=/root/SLAM-LLM +cd $run_dir +code_dir=examples/vsr_LRS3 + 
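+# Same decoding recipe as decode_avhubert_vo_vicuna_7b.sh, but the encoder path below points at
+# large_vox_433h.pt, i.e. presumably the AV-HuBERT Large checkpoint released without the extra
+# self-training stage (hence the "_noself" suffix).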
+speech_encoder_path=/nfs/yangguanrou.ygr/codes/av_hubert/large_vox_433h.pt +llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5 + +output_dir=/nfs/yangguanrou.ygr/experiments_avhubert/vicuna-7b-v1.5-large_vox_433h-tri-dataset-tiaocan_again +ckpt_path=$output_dir/asr/850 + +decode_log=$ckpt_path/decode_${split}_beam4_noself + +# -m debugpy --listen 5678 --wait-for-client +python $code_dir/inference_vsr_batch.py \ + --config-path "conf" \ + --config-name "prompt.yaml" \ + hydra.run.dir=$ckpt_path \ + +model_config.llm_name="vicuna-7b-v1.5" \ + +model_config.llm_path=$llm_path \ + +model_config.llm_dim=4096 \ + +model_config.encoder_name=av_hubert \ + +model_config.encoder_projector_ds_rate=5 \ + +model_config.encoder_path=$speech_encoder_path \ + +model_config.encoder_dim=1024 \ + +model_config.encoder_projector=cov1d-linear \ + +dataset_config.dataset=avhubert_dataset \ + +dataset_config.inference_mode=true \ + +dataset_config.test_split=test \ + +train_config.model_name=vsr \ + +train_config.freeze_encoder=true \ + +train_config.freeze_llm=true \ + +train_config.batching_strategy=custom \ + +train_config.num_epochs=1 \ + +train_config.val_batch_size=8 \ + +train_config.num_workers_dataloader=0 \ + +train_config.output_dir=$output_dir \ + +decode_log=$decode_log \ + +ckpt_path=$ckpt_path/model.pt \ + # +peft_ckpt=$ckpt_path \ + # +train_config.use_peft=true \ + # +train_config.peft_config.r=32 \ + # +dataset_config.normalize=true \ + # +model_config.encoder_projector=q-former \ + # +dataset_config.fix_length_audio=64 \ + +python src/slam_llm/utils/whisper_tn.py ${decode_log}_gt ${decode_log}_gt.proc +python src/slam_llm/utils/whisper_tn.py ${decode_log}_pred ${decode_log}_pred.proc +python src/slam_llm/utils/compute_wer.py ${decode_log}_gt.proc ${decode_log}_pred.proc ${decode_log}.proc.wer diff --git a/examples/vsr_LRS3/scripts/finetune_avhubert_vo_vicuna_7b.sh b/examples/vsr_LRS3/scripts/finetune_avhubert_vo_vicuna_7b.sh new file mode 100755 index 00000000..e3bf181c --- /dev/null +++ b/examples/vsr_LRS3/scripts/finetune_avhubert_vo_vicuna_7b.sh @@ -0,0 +1,73 @@ +#!/bin/bash +# export PYTHONPATH=/root/whisper:$PYTHONPATH +# export PYTHONPATH=/root/fairseq:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=3 +export TOKENIZERS_PARALLELISM=false +# export CUDA_LAUNCH_BLOCKING=1 +export OMP_NUM_THREADS=1 + +# debug setting for multiple gpus +# export NCCL_DEBUG=INFO +# export NCCL_DEBUG_SUBSYS=ALL +# export TORCH_DISTRIBUTED_DEBUG=INFO + +run_dir=/root/SLAM-LLM +cd $run_dir +code_dir=examples/vsr_LRS3 + +speech_encoder_path=/nfs/yangguanrou.ygr/codes/av_hubert/self_large_vox_433h_new.pt +llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5 + +output_dir=/root/tmp/vicuna-7b-v1.5-large_vox_433h-$(date +"%Y%m%d") + +hydra_args=" +hydra.run.dir=$output_dir \ ++model_config.llm_name=vicuna-7b-v1.5 \ ++model_config.llm_path=$llm_path \ ++model_config.llm_dim=4096 \ ++model_config.encoder_name=av_hubert \ ++model_config.encoder_path=$speech_encoder_path \ ++model_config.encoder_dim=1024 \ ++model_config.encoder_projector=cov1d-linear \ ++model_config.encoder_projector_ds_rate=5 \ ++dataset_config.dataset=avhubert_dataset \ ++dataset_config.input_type=raw \ ++dataset_config.labels=[\"wrd\"] \ ++train_config.model_name=vsr \ ++train_config.num_epochs=10 \ ++train_config.freeze_encoder=true \ ++train_config.freeze_llm=true \ ++train_config.batching_strategy=custom \ ++train_config.warmup_steps=1000 \ ++train_config.total_steps=70000 \ ++train_config.lr=5e-3 \ ++train_config.validation_interval=2000 \ 
++train_config.batch_size_training=12 \ ++train_config.val_batch_size=12 \ ++train_config.num_workers_dataloader=0 \ ++train_config.output_dir=$output_dir \ ++metric=acc \ +" + +# -m debugpy --listen 5678 --wait-for-client +if [[ $CUDA_VISIBLE_DEVICES != *","* ]]; then + python -m debugpy --listen 5678 --wait-for-client $code_dir/finetune_vsr.py \ + --config-path "conf" \ + --config-name "prompt.yaml" \ + $hydra_args +else + torchrun \ + --nnodes 1 \ + --nproc_per_node 2 \ + --master_port=29503 \ + $code_dir/finetune_vsr.py \ + --config-path "conf" \ + --config-name "prompt.yaml" \ + +train_config.enable_fsdp=false \ + +train_config.enable_ddp=true \ + +train_config.use_fp16=true \ + $hydra_args +fi + + +#+dataset_config.dataset=[\"wrd\"] \ \ No newline at end of file diff --git a/examples/vsr_LRS3/vsr_config.py b/examples/vsr_LRS3/vsr_config.py new file mode 100644 index 00000000..3c051a01 --- /dev/null +++ b/examples/vsr_LRS3/vsr_config.py @@ -0,0 +1,133 @@ +from dataclasses import dataclass, field +from typing import Optional, List +@dataclass +class ModelConfig: + file: str = "examples/vsr_LRS3/model/slam_model_vsr.py:model_factory" + llm_name: str = "vicuna-7b-v1.5" + llm_path: str = "PATH/to/Vicuna/7B" + llm_type: str = "decoder_only" + llm_dim: int = 4096 + encoder_name: Optional[str] = "av_hubert" + encoder_path: Optional[str] = "PATH/to/self_large_vox_433h.pt" + encoder_dim: int = 1024 + encoder_projector: str = "cov1d-linear" + encoder_projector_ds_rate: int = 5 + +@dataclass +class PeftConfig: + peft_method: str = "lora" # None , llama_adapter, prefix + r: int = 8 + lora_alpha: int = 32 + target_modules: List = field(default_factory=lambda: [ "q_proj", "v_proj" ]) + bias: str = "none" + task_type: str = "CAUSAL_LM" + lora_dropout: float = 0.05 + inference_mode: bool = False + +@dataclass +class TrainConfig: + model_name:str = "av_hubert" + enable_ddp:bool = False + enable_deepspeed:bool = False + enable_fsdp:bool = False + low_cpu_fsdp:bool = False + run_validation:bool = True + batch_size_training:int = 4 + batching_strategy:str = field(default="packing", metadata={ + "help":"alternative: padding" + }) # + context_length:int = 4096 + gradient_accumulation_steps:int = 1 + num_epochs:int = 3 + num_workers_dataloader:int = 1 + warmup_steps:int = 1000 + total_steps:int = 100000 + validation_interval:int = 1000 + lr:float = 1e-4 + weight_decay:float = 0.0 + gamma:float = 0.85 + seed:int = 42 + use_fp16:bool = False + mixed_precision:bool = True + val_batch_size:int = 1 + + use_peft:bool = False + peft_config:PeftConfig = field(default_factory=PeftConfig) + output_dir:str = "PATH/to/save/PEFT/model" + freeze_layers:bool = False + num_freeze_layers:int = 1 + quantization:bool = False + one_gpu:bool = False + save_model:bool = True + dist_checkpoint_root_folder:str = "PATH/to/save/FSDP/model" # will be used if using FSDP + dist_checkpoint_folder:str = "fine-tuned" # will be used if using FSDP + save_optimizer:bool = False # will be used if using FSDP + use_fast_kernels:bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels + run_test_during_validation:bool = False + run_test_during_validation_file:str = "test.wav" + run_test_during_validation_prompt:str = "<|ASR|>" + freeze_llm:bool = field(default=False, metadata={ + "help": "whether to freeze llm when finetuning, should be True when use peft finetuning" + }) + freeze_encoder:bool = False + +@dataclass +class DataConfig: + fix_length_audio: int = -1 + 
inference_mode: bool = False + dataset: str = "avhubert_dataset" + file: str = "src/slam_llm/datasets/avhubert_dataset.py:get_audio_dataset" + data: str = "/nfs/yangguanrou.ygr/LRS_new/433h_data/" + train_split: str = "train" + test_split: str = "val" + labels: List = field(default_factory=lambda:["wrd"]) + label_dir: str = "/nfs/yangguanrou.ygr/LRS_new/433h_data" + label_rate: int = -1 + is_s2s: bool = True + noise_wav: Optional[str] = None + noise_snr: str = "0.0" + noise_num: int = 1 + sample_rate: int = 16000 + normalize: bool = True + enable_padding: bool = False + max_sample_size: int = 500 + min_sample_size: int = 0 + max_trim_sample_size: int = 500 + single_target: bool = True + random_crop: bool = False + pad_audio: bool = True + pdb: bool = False + stack_order_audio: int = 1 + skip_verify: bool = False + # image_aug: True + image_crop_size: int = 88 + image_mean: float = 0.421 + image_std: float = 0.165 + noise_prob: float = field(default=0, metadata={'help': 'noise probability'}) + fine_tuning: bool = True + modal: str = "VO" + modalities: List = field(default_factory=lambda: ['video']) + shuffle: bool = True + prompt: str = "Transcribe the silent speech in this video to text by lip-reading the speaker's clear and visible lip movements." + +@dataclass +class FSDPConfig: + mixed_precision: bool = True + use_fp16: bool = False + # sharding_strategy = "FULL_SHARD" #ShardingStrategy = ShardingStrategy.FULL_SHARD + sharding_strategy: str = "NO_SHARD" #ShardingStrategy.NO_SHARD #MZY: set NO_SHARD when use DDP + checkpoint_type: str = "SHARDED_STATE_DICT" # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size. + fsdp_activation_checkpointing: bool = True + fsdp_cpu_offload: bool = False + pure_bf16: bool = False + optimizer: str = "AdamW" + +@dataclass +class LogConfig: + use_wandb: bool = False + wandb_dir: str = "/root/test_wandb" + wandb_entity_name: str = "project_name" + wandb_project_name: str = "project_name" + wandb_exp_name: str = "exp_name" + log_file: str = "/root/test.log" + log_interval: int = 5 diff --git a/scripts/compute_wer.sh b/scripts/compute_wer.sh index 00acf7c9..28663de9 100755 --- a/scripts/compute_wer.sh +++ b/scripts/compute_wer.sh @@ -5,6 +5,8 @@ preds="/nfs/maziyang.mzy/exps/vicuna-7b-v1.5-finetune-asr-ds10-proj2048-steplrwa # python src/llama_recipes/utils/preprocess_text.py ${preds} ${preds}.proc # python src/llama_recipes/utils/compute_wer.py ${trans} ${preds}.proc ${preds}.proc.wer +trans=/nfs/yangguanrou.ygr/experiments_avhubert/vicuna-7b-v1.5-large_vox_433h-tri-dataset-tiaocan_again/asr/850/decode__beam4_noself_gt +preds=/nfs/yangguanrou.ygr/experiments_avhubert/vicuna-7b-v1.5-large_vox_433h-tri-dataset-tiaocan_again/asr/850/decode__beam4_noself_pred python src/llama_recipes/utils/whisper_tn.py ${trans} ${trans}.proc python src/llama_recipes/utils/llm_tn.py ${preds} ${preds}.proc diff --git a/scripts/finetune_avhubert_tri_dataset_tiaocan.sh b/scripts/finetune_avhubert_tri_dataset_tiaocan.sh new file mode 100644 index 00000000..5814778f --- /dev/null +++ b/scripts/finetune_avhubert_tri_dataset_tiaocan.sh @@ -0,0 +1,71 @@ +#!/bin/bash +# export PYTHONPATH=/root/fairseq:$PYTHONPATH +export CUDA_VISIBLE_DEVICES=2 +export TOKENIZERS_PARALLELISM=false +# export CUDA_LAUNCH_BLOCKING=1 +export OMP_NUM_THREADS=1 + +export MASTER_ADDR=localhost # Or the actual IP if it's a remote cluster +export MASTER_PORT=12346 # A free port number +export WORLD_SIZE=1 # Assuming you have 4 GPUs +export RANK=0 +export LOCAL_RANK=0 + +cd 
/root/SLAM-LLM + + +speech_encoder_path=/nfs/yangguanrou.ygr/codes/av_hubert/self_large_vox_433h_new.pt + +llm_path=/nfs/maziyang.mzy/models/vicuna-7b-v1.5 + +output_dir=/nfs/yangguanrou.ygr/debug +# ckpt_path=/nfs/maziyang.mzy/exps/llama-2-hf-finetune-asr-ds5-proj2048-lr1e-4-whisper-prompt-paddinglrfix8000-20240106/asr/2/model.pt + + +python -m debugpy --listen 5678 --wait-for-client src/llama_recipes/pipeline/finetune.py \ +--config-path "/root/SLAM-LLM/scripts/conf_avsr" \ +--config-name "avsr.yaml" \ +hydra.run.dir=$output_dir \ +model_config.llm_name="vicuna-7b-v1.5" \ +model_config.llm_path=$llm_path \ +model_config.llm_dim=4096 \ +model_config.encoder_name=av_hubert \ +model_config.encoder_path=$speech_encoder_path \ +model_config.encoder_dim=1024 \ +model_config.encoder_projector=cov1d-linear \ +model_config.encoder_projector_ds_rate=5 \ +dataset_config.dataset=avhubert_dataset \ +dataset_config.file="src/llama_recipes/datasets/avhubert_dataset.py:get_audio_dataset" \ +model_config.modal=VO \ +train_config.model_name=asr \ +train_config.freeze_encoder=true \ +train_config.freeze_llm=true \ +train_config.batching_strategy=custom \ +train_config.warmup_steps=1000 \ +train_config.total_steps=70000 \ +train_config.lr=5e-3 \ +train_config.scheduler=tri \ +train_config.validation_interval=2000 \ +train_config.batch_size_training=8 \ +train_config.val_batch_size=8 \ +train_config.num_workers_dataloader=0 \ +train_config.output_dir=$output_dir \ +train_config.enable_fsdp=false \ +train_config.enable_ddp=true \ +train_config.use_fp16=true \ ++metric=acc \ +# log_config.log_file=/$output_dir/train.log \ +# log_config.use_wandb=true \ +# log_config.wandb_dir=$output_dir \ +# log_config.wandb_entity_name=yanghaha \ +# log_config.wandb_project_name=slam-llm-vox \ +# log_config.wandb_exp_name=vicuna-7b-v1.5-large_vox_433h-tri-dataset-tiaocan \ +# log_config.log_interval=10 \ + + + + + + +# cd /root +# cp -r SLAM-LLM/ /nfs/yangguanrou.ygr/codes/ \ No newline at end of file diff --git a/src/slam_llm/datasets/avhubert_dataset.py b/src/slam_llm/datasets/avhubert_dataset.py new file mode 100644 index 00000000..a0db43fc --- /dev/null +++ b/src/slam_llm/datasets/avhubert_dataset.py @@ -0,0 +1,595 @@ +import os.path as osp +import random +import json, yaml +import copy + +import numpy as np +from scipy import signal +import soundfile as sf + +import torch +import torchaudio +from torch.utils.data import Dataset +import whisper + +import itertools +import os +import sys +import time +from typing import Any, List, Optional, Union +import torch.nn.functional as F +from fairseq.data import data_utils +from python_speech_features import logfbank +from scipy.io import wavfile +import slam_llm.utils.custom_utils as custom_utils + + +import logging +logger = logging.getLogger(__name__) + +def load_audio_visual(manifest_path, max_keep, min_keep, frame_rate, label_paths, label_rates, tol=0.1): + def is_audio_label_aligned(audio_dur, label_durs): + return all([abs(audio_dur - label_dur) max_keep: + n_long += 1 + elif (not is_seq_label) and (not is_audio_label_aligned(sz/frame_rate, dur_from_label_list[ind])): + n_unaligned += 1 + else: + video_path = items[1] + audio_path = items[2] + audio_id = items[0] + names.append((video_path, audio_path+':'+audio_id)) + inds.append(ind) + sizes.append(sz) + tot = ind + 1 + logger.info( + ( + f"max_keep={max_keep}, min_keep={min_keep}, " + f"loaded {len(names)}, skipped {n_short} short and {n_long} long and {n_unaligned} unaligned, " + f"longest-loaded={max(sizes)}, 
shortest-loaded={min(sizes)}" + ) + ) + return root, names, inds, tot, sizes + +def load_label(label_path, inds, tot): + with open(label_path) as f: + labels = [line.rstrip() for line in f] + assert ( + len(labels) == tot + ), f"number of labels does not match ({len(labels)} != {tot})" + labels = [labels[i] for i in inds] + return labels + + +def load_label_offset(label_path, inds, tot): + with open(label_path) as f: + code_lengths = [len(line.encode("utf-8")) for line in f] + assert ( + len(code_lengths) == tot + ), f"number of labels does not match ({len(code_lengths)} != {tot})" + offsets = list(itertools.accumulate([0] + code_lengths)) + offsets = [(offsets[i], offsets[i + 1]) for i in inds] + return offsets + + +def verify_label_lengths( + audio_sizes, + audio_rate, + label_path, + label_rate, + inds, + tot, + tol=0.1, # tolerance in seconds +): + if label_rate < 0: + logger.info(f"{label_path} is sequence label. skipped") + return + + with open(label_path) as f: + lengths = [len(line.rstrip().split()) for line in f] + assert len(lengths) == tot + lengths = [lengths[i] for i in inds] + num_invalid = 0 + for i, ind in enumerate(inds): + dur_from_audio = audio_sizes[i] / audio_rate + dur_from_label = lengths[i] / label_rate + if abs(dur_from_audio - dur_from_label) > tol: + logger.warning( + ( + f"audio and label duration differ too much " + f"(|{dur_from_audio} - {dur_from_label}| > {tol}) " + f"in line {ind+1} of {label_path}. Check if `label_rate` " + f"is correctly set (currently {label_rate}). " + f"num. of samples = {audio_sizes[i]}; " + f"label length = {lengths[i]}" + ) + ) + num_invalid += 1 + if num_invalid > 0: + logger.warning( + f"total {num_invalid} (audio, label) pairs with mismatched lengths" + ) + +class AVHubertdataset(torch.utils.data.Dataset): + def __init__(self,dataset_config,tokenizer=None,split='train',): + super().__init__() + self.dataset_config = dataset_config + self.tokenizer = tokenizer + self.IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss + self.prompt = dataset_config.get("prompt", None) + self.prompt_template = "USER: {}\n ASSISTANT:" + self.answer_template = "{}" + self.fix_length_audio = dataset_config.get("fix_length_audio", -1) + self.inference_mode = dataset_config.get("inference_mode", False) + self.data= dataset_config.data + self.manifest = f"{self.data}/{split}.tsv" + + paths = [ + f"{dataset_config.label_dir}/{split}.{l}" for l in dataset_config.labels + ] + label_paths=paths + image_aug = False + noise_fn, noise_snr = f"{self.cfg.noise_wav}/{split}.tsv" if dataset_config.noise_wav is not None else None, eval(dataset_config.noise_snr) + noise_num = dataset_config.noise_num + store_labels=False + + + self.label_rates = ( + [dataset_config.label_rate for _ in range(len(label_paths))] + if isinstance(dataset_config.label_rate, int) + else dataset_config.label_rate + ) + self.modalities = set(dataset_config.modalities) + self.audio_root, self.names, inds, tot, self.sizes = load_audio_visual(self.manifest, dataset_config.max_sample_size, dataset_config.min_sample_size, frame_rate=dataset_config.sample_rate, label_paths=label_paths, label_rates=self.label_rates) + self.sample_rate = dataset_config.sample_rate + self.stack_order_audio = dataset_config.stack_order_audio + self.shuffle = dataset_config.shuffle + self.random_crop = dataset_config.random_crop + + self.num_labels = len(label_paths) + self.single_target = dataset_config.single_target + self.store_labels = False + self.is_s2s = dataset_config.is_s2s + self.noise_wav, 
self.noise_prob, self.noise_snr, self.noise_num = [ln.strip() for ln in open(noise_fn).readlines()] if noise_fn is not None else [], dataset_config.noise_prob, noise_snr, noise_num + + assert self.single_target == (self.label_rates[0] == -1), f"single target should be equivalent to sequence label (label_rate==-1)" + if store_labels: + self.label_list = [load_label(p, inds, tot) for p in label_paths] + else: + self.label_paths = label_paths + self.label_offsets_list = [ + load_label_offset(p, inds, tot) for p in label_paths + ] + if not dataset_config.skip_verify: + for label_path, label_rate in zip(label_paths, self.label_rates): + verify_label_lengths(self.sizes, self.sample_rate, label_path, label_rate, inds, tot) + else: + logger.info(f"Skip label alignment verifying") + + self.max_sample_size = ( + dataset_config.max_sample_size if dataset_config.max_sample_size is not None else sys.maxsize + ) + self.pad_audio = dataset_config.pad_audio + self.normalize = dataset_config.normalize + if image_aug: + self.transform = custom_utils.Compose([ + custom_utils.Normalize( 0.0,255.0 ), + custom_utils.RandomCrop((dataset_config.image_crop_size, dataset_config.image_crop_size)), + custom_utils.HorizontalFlip(0.5), + custom_utils.Normalize(dataset_config.image_mean, dataset_config.image_std) ]) + else: + self.transform = custom_utils.Compose([ + custom_utils.Normalize( 0.0,255.0 ), + custom_utils.CenterCrop((dataset_config.image_crop_size, dataset_config.image_crop_size)), + custom_utils.Normalize(dataset_config.image_mean, dataset_config.image_std) ]) + logger.info(f"image transform: {self.transform}") + + logger.info( + f"pad_audio={self.pad_audio}, random_crop={self.random_crop}, " + f"normalize={self.normalize}, max_sample_size={self.max_sample_size}, " + f"seqs2seq data={self.is_s2s},") + logger.info( + f"Noise wav: {noise_fn}->{len(self.noise_wav)} wav, Prob: {self.noise_prob}, SNR: {self.noise_snr}, Number of mixture: {self.noise_num}" + ) + + def get_label(self, index, label_idx): + if self.store_labels: + label = self.label_list[label_idx][index] + else: + with open(self.label_paths[label_idx]) as f: + offset_s, offset_e = self.label_offsets_list[label_idx][index] + f.seek(offset_s) + label = f.read(offset_e - offset_s) + return label + + def get_labels(self, index): + return [self.get_label(index, i) for i in range(self.num_labels)] + + def load_feature(self, mix_name): + """ + Load image and audio feature + Returns: + video_feats: numpy.ndarray of shape [T, H, W, 1], audio_feats: numpy.ndarray of shape [T, F] + """ + def stacker(feats, stack_order): + """ + Concatenating consecutive audio frames + Args: + feats - numpy.ndarray of shape [T, F] + stack_order - int (number of neighboring frames to concatenate + Returns: + feats - numpy.ndarray of shape [T', F'] + """ + feat_dim = feats.shape[1] + if len(feats) % stack_order != 0: + res = stack_order - len(feats) % stack_order + res = np.zeros([res, feat_dim]).astype(feats.dtype) + feats = np.concatenate([feats, res], axis=0) + feats = feats.reshape((-1, stack_order, feat_dim)).reshape(-1, stack_order*feat_dim) + return feats + video_fn, audio_fn = mix_name + if 'video' in self.modalities: + video_feats = self.load_video(video_fn) # [T, H, W, 1] + else: + video_feats = None + if 'audio' in self.modalities: + audio_fn = audio_fn.split(':')[0] + sample_rate, wav_data = wavfile.read(audio_fn) + assert sample_rate == 16_000 and len(wav_data.shape) == 1 + if np.random.rand() < self.noise_prob: + wav_data = self.add_noise(wav_data) + audio_feats = 
logfbank(wav_data, samplerate=sample_rate).astype(np.float32) # [T, F] + audio_feats = stacker(audio_feats, self.stack_order_audio) # [T/stack_order_audio, F*stack_order_audio] + else: + audio_feats = None + if audio_feats is not None and video_feats is not None: + diff = len(audio_feats) - len(video_feats) + if diff < 0: + audio_feats = np.concatenate([audio_feats, np.zeros([-diff, audio_feats.shape[-1]], dtype=audio_feats.dtype)]) + elif diff > 0: + audio_feats = audio_feats[:-diff] + return video_feats, audio_feats + + def load_video(self, audio_name): + feats = custom_utils.load_video(os.path.join(self.audio_root, audio_name)) + feats = self.transform(feats) + feats = np.expand_dims(feats, axis=-1) + return feats + + def select_noise(self): + rand_indexes = np.random.randint(0, len(self.noise_wav), size=self.noise_num) + noise_wav = [] + for x in rand_indexes: + noise_wav.append(wavfile.read(self.noise_wav[x])[1].astype(np.float32)) + if self.noise_num == 1: + return noise_wav[0] + else: + min_len = min([len(x) for x in noise_wav]) + noise_wav = [x[:min_len] for x in noise_wav] + noise_wav = np.floor(np.stack(noise_wav).mean(axis=0)) + return noise_wav + + def add_noise(self, clean_wav): + clean_wav = clean_wav.astype(np.float32) + noise_wav = self.select_noise() + if type(self.noise_snr) == int or type(self.noise_snr) == float: + snr = self.noise_snr + elif type(self.noise_snr) == tuple: + snr = np.random.randint(self.noise_snr[0], self.noise_snr[1]+1) + clean_rms = np.sqrt(np.mean(np.square(clean_wav), axis=-1)) + if len(clean_wav) > len(noise_wav): + ratio = int(np.ceil(len(clean_wav)/len(noise_wav))) + noise_wav = np.concatenate([noise_wav for _ in range(ratio)]) + if len(clean_wav) < len(noise_wav): + start = 0 + noise_wav = noise_wav[start: start + len(clean_wav)] + noise_rms = np.sqrt(np.mean(np.square(noise_wav), axis=-1)) + adjusted_noise_rms = clean_rms / (10**(snr/20)) + adjusted_noise_wav = noise_wav * (adjusted_noise_rms / noise_rms) + mixed = clean_wav + adjusted_noise_wav + + #Avoid clipping noise + max_int16 = np.iinfo(np.int16).max + min_int16 = np.iinfo(np.int16).min + if mixed.max(axis=0) > max_int16 or mixed.min(axis=0) < min_int16: + if mixed.max(axis=0) >= abs(mixed.min(axis=0)): + reduction_rate = max_int16 / mixed.max(axis=0) + else : + reduction_rate = min_int16 / mixed.min(axis=0) + mixed = mixed * (reduction_rate) + mixed = mixed.astype(np.int16) + return mixed + + def __getitem__(self, index): + video_feats, audio_feats = self.load_feature(self.names[index]) + audio_feats, video_feats = torch.from_numpy(audio_feats.astype(np.float32)) if audio_feats is not None else None, torch.from_numpy(video_feats.astype(np.float32)) if video_feats is not None else None + if self.normalize and 'audio' in self.modalities: + with torch.no_grad(): + audio_feats = F.layer_norm(audio_feats, audio_feats.shape[1:]) + labels = self.get_labels(index) + target = labels[0].replace("\n", "") + fid = self.names[index][1].split(':')[1] + + if self.dataset_config.modal=="AV": + prompt = "Transcribe video to text. " + elif self.dataset_config.modal=="AO": + prompt = "Transcribe speech to text. " + elif self.dataset_config.modal=="VO": + prompt = self.dataset_config.prompt #"Transcribe the silent speech in this video to text by lip-reading the speaker's clear and visible lip movements." 
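+        # The instruction chosen above is wrapped in the Vicuna-style "USER: ... ASSISTANT:" template,
+        # and a pseudo sequence of length audio_length is prepended as a placeholder for the visual
+        # features; unless fix_length_audio is set, that length follows the 5x downsampling of the
+        # encoder projector (encoder_projector_ds_rate=5).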
+ + prompt = self.prompt_template.format(prompt) + prompt_ids = self.tokenizer.encode(prompt) + prompt_length = len(prompt_ids) + + audio_length = video_feats.shape[0] + + if self.fix_length_audio > 0: + audio_length = self.fix_length_audio + else: + audio_length = audio_length // 5 # ad-hoc for 5x fc downsample + audio_pseudo = torch.full((audio_length,), -1) # placeholder + + if self.inference_mode: + prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64) + example_ids = torch.cat((audio_pseudo, prompt_ids)) # [audio,prompt] + example_mask = example_ids.ge(-1) # [True,True] + + return { + "id": index, + "input_ids": example_ids, + "attention_mask": example_mask, + 'audio_source': audio_feats, + "video_source": video_feats, + 'audio_length': audio_length, + 'key': fid, + 'target': target, + } + + answer = self.answer_template.format(target) + example = prompt + answer # FIX(MZY): avoid putting a bos token before answer. + example_ids = self.tokenizer.encode(example) # [prompt,answer] + example_ids.append(self.tokenizer.eos_token_id) # [prompt,answer,eos] + example_ids = torch.tensor( + example_ids, dtype=torch.int64 + ) + example_ids = torch.cat((audio_pseudo, example_ids)) # [audio,prompt,answer,eos] + + labels_ids = copy.deepcopy(example_ids) # [audio,prompt,answer,eos] + labels_ids[:audio_length + prompt_length] = -1 # [-1,-1,answer,eos]; + example_mask = example_ids.ge(-1) # FIX(GZF): [True,True,True,True] + + label_mask = labels_ids.ge(0) # [False,False,True,True] + example_ids[~example_mask] = 0 # [audio,prompt,answer,eos] + labels_ids[~label_mask] = self.IGNORE_INDEX # [-100,-100,answer,eos] + + return { + "id": index, + "input_ids": example_ids, + "labels": labels_ids, + "attention_mask": example_mask, + 'audio_source': audio_feats, + "video_source": video_feats, + 'audio_length': audio_length, + } + + def __len__(self): + return len(self.sizes) + + def crop_to_max_size(self, wav, target_size, start=None): + size = len(wav) + diff = size - target_size + if diff <= 0: + return wav, 0 + # longer utterances + if start is None: + start, end = 0, target_size + if self.random_crop: + start = np.random.randint(0, diff + 1) + end = size - diff + start + else: + end = start + target_size + return wav[start:end], start + + def collator(self, samples): + samples = [s for s in samples if s["id"] is not None] + if len(samples) == 0: + return {} + + audio_source, video_source = [s["audio_source"] for s in samples], [s["video_source"] for s in samples] + if audio_source[0] is None: + audio_source = None + if video_source[0] is None: + video_source = None + if audio_source is not None: + audio_sizes = [len(s) for s in audio_source] + else: + audio_sizes = [len(s) for s in video_source] + if self.pad_audio: + audio_size = min(max(audio_sizes), self.max_sample_size) + else: + audio_size = min(min(audio_sizes), self.max_sample_size) + if audio_source is not None: + collated_audios, padding_mask, audio_starts = self.collater_audio(audio_source, audio_size) + else: + collated_audios, audio_starts = None, None + if video_source is not None: + collated_videos, padding_mask, audio_starts = self.collater_audio(video_source, audio_size, audio_starts) + else: + collated_videos = None + + input_ids_max_length = max([s['input_ids'].shape[0] for s in samples]) + input_ids = torch.stack([self.pad(s['input_ids'], input_ids_max_length, self.tokenizer.pad_token_id) for s in samples]) + attention_mask = torch.stack([self.pad(s['attention_mask'], input_ids_max_length, False) for s in samples]) + + modality_mask = 
torch.zeros_like(attention_mask) + for line, sample in enumerate(samples): + modality_mask[line, :sample['audio_length']] = 1 + + if self.inference_mode: + keys = [s['key'] for s in samples] + targets = [s['target'] for s in samples] + + return { + 'input_ids': input_ids, + 'attention_mask': attention_mask, + 'modality_mask': modality_mask, + 'keys': keys, + 'targets': targets, + + "audio": collated_audios, + "audio_mask": padding_mask, + "visual": collated_videos, + "visual_mask": padding_mask, + } + + labels = torch.stack([self.pad(s['labels'], input_ids_max_length, self.IGNORE_INDEX) for s in samples]) + + return { + 'input_ids': input_ids, + 'labels': labels, + 'attention_mask': attention_mask, + 'modality_mask': modality_mask, + + "audio": collated_audios, + "audio_mask": padding_mask, + "visual": collated_videos, + "visual_mask": padding_mask, + } + + def collater_audio(self, audios, audio_size, audio_starts=None): + audio_feat_shape = list(audios[0].shape[1:]) + collated_audios = audios[0].new_zeros([len(audios), audio_size]+audio_feat_shape) + padding_mask = ( + torch.BoolTensor(len(audios), audio_size).fill_(False) + ) + start_known = audio_starts is not None + audio_starts = [0 for _ in audios] if not start_known else audio_starts + for i, audio in enumerate(audios): + diff = len(audio) - audio_size + if diff == 0: + collated_audios[i] = audio + elif diff < 0: + assert self.pad_audio + collated_audios[i] = torch.cat( + [audio, audio.new_full([-diff]+audio_feat_shape, 0.0)] + ) + padding_mask[i, diff:] = True + else: + collated_audios[i], audio_starts[i] = self.crop_to_max_size( + audio, audio_size, audio_starts[i] if start_known else None + ) + if len(audios[0].shape) == 2: + collated_audios = collated_audios.transpose(1, 2) # [B, T, F] -> [B, F, T] + else: + collated_audios = collated_audios.permute((0, 4, 1, 2, 3)).contiguous() # [B, T, H, W, C] -> [B, C, T, H, W] + return collated_audios, padding_mask, audio_starts + + def collater_frm_label( + self, targets, audio_size, audio_starts, label_rate, pad + ): + assert label_rate > 0 + s2f = label_rate / self.sample_rate # num label per sample + frm_starts = [int(round(s * s2f)) for s in audio_starts] + frm_size = int(round(audio_size * s2f)) + if not self.pad_audio: + rem_size = [len(t) - s for t, s in zip(targets, frm_starts)] + frm_size = min(frm_size, *rem_size) + targets = [t[s: s + frm_size] for t, s in zip(targets, frm_starts)] + logger.debug(f"audio_starts={audio_starts}") + logger.debug(f"frame_starts={frm_starts}") + logger.debug(f"frame_size={frm_size}") + + lengths = torch.LongTensor([len(t) for t in targets]) + ntokens = lengths.sum().item() + targets = data_utils.collate_tokens( + targets, pad_idx=pad, left_pad=False + ) + return targets, lengths, ntokens + + def collater_seq_label(self, targets, pad): + lengths = torch.LongTensor([len(t) for t in targets]) + ntokens = lengths.sum().item() + targets = data_utils.collate_tokens( + targets, pad_idx=pad, left_pad=False + ) + return targets, lengths, ntokens + + def collater_seq_label_s2s(self, targets, pad): + lengths = torch.LongTensor([len(t) for t in targets]) + ntokens = lengths.sum().item() + pad, eos = self.label_processors[0].dictionary.pad(), self.label_processors[0].dictionary.eos() + targets_ = data_utils.collate_tokens(targets, pad_idx=pad, eos_idx=eos, left_pad=False) + prev_output_tokens = data_utils.collate_tokens(targets, pad_idx=pad, eos_idx=eos, left_pad=False, move_eos_to_beginning=True) + return (targets_, prev_output_tokens), lengths, ntokens + + def 
collater_label(self, targets_by_label, audio_size, audio_starts): + targets_list, lengths_list, ntokens_list = [], [], [] + itr = zip(targets_by_label, self.label_rates, self.pad_list) + for targets, label_rate, pad in itr: + if label_rate == -1: + if self.is_s2s: + targets, lengths, ntokens = self.collater_seq_label_s2s(targets, pad) + else: + targets, lengths, ntokens = self.collater_seq_label(targets, pad) + else: + targets, lengths, ntokens = self.collater_frm_label( + targets, audio_size, audio_starts, label_rate, pad + ) + targets_list.append(targets) + lengths_list.append(lengths) + ntokens_list.append(ntokens) + return targets_list, lengths_list, ntokens_list + + def num_tokens(self, index): + return self.size(index) + + def size(self, index): + if self.pad_audio: + return self.sizes[index] + return min(self.sizes[index], self.max_sample_size) + + def ordered_indices(self): + if self.shuffle: + order = [np.random.permutation(len(self))] + else: + order = [np.arange(len(self))] + + order.append(self.sizes) + return np.lexsort(order)[::-1] + + def pad(self, sequence, max_length, padding_idx=0): + if isinstance(sequence, (int, list, tuple)): + if len(sequence) < max_length: + sequence = sequence + [padding_idx] * (max_length - len(sequence)) + else: + sequence = sequence[:max_length] + elif isinstance(sequence, torch.Tensor): + if len(sequence) < max_length: + sequence = torch.cat( + (sequence, torch.full(([max_length - len(sequence)] + list(sequence.size())[1:]), padding_idx))) + else: + sequence = sequence[:max_length] + else: + raise Exception("Type mismatch during padding!") + return sequence + +def get_audio_dataset(dataset_config, tokenizer, split): + dataset = AVHubertdataset(dataset_config, tokenizer, split) + return dataset \ No newline at end of file diff --git a/src/slam_llm/models/avhubert/__init__.py b/src/slam_llm/models/avhubert/__init__.py new file mode 100644 index 00000000..9495f2bb --- /dev/null +++ b/src/slam_llm/models/avhubert/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .hubert import * # noqa +from .hubert_asr import * # noqa +from .hubert_dataset import * +from .hubert_pretraining import * +from .hubert_criterion import * diff --git a/src/slam_llm/models/avhubert/decoder.py b/src/slam_llm/models/avhubert/decoder.py new file mode 100644 index 00000000..78de423a --- /dev/null +++ b/src/slam_llm/models/avhubert/decoder.py @@ -0,0 +1,243 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
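+# Seq2seq Transformer decoder carried over from the upstream AV-HuBERT codebase; hubert.py in this
+# directory imports it (from .decoder import TransformerDecoder), so it is ported alongside it.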
+ +from argparse import Namespace +import contextlib +import copy +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from dataclasses import dataclass, field +from omegaconf import MISSING, II, open_dict +from typing import Any, Optional + +from fairseq import checkpoint_utils, tasks, utils +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.tasks import FairseqTask +from fairseq.models import ( + BaseFairseqModel, + FairseqEncoder, + FairseqEncoderDecoderModel, + FairseqIncrementalDecoder, + register_model, +) +# from fairseq.models.wav2vec.wav2vec2 import MASKING_DISTRIBUTION_CHOICES +from fairseq.modules import ( + LayerNorm, + PositionalEmbedding, + TransformerDecoderLayer, +) + + +class TransformerDecoder(FairseqIncrementalDecoder): + """ + Transformer decoder consisting of *args.decoder_layers* layers. Each layer + is a :class:`TransformerDecoderLayer`. + + Args: + args (argparse.Namespace): parsed command-line arguments + dictionary (~fairseq.data.Dictionary): decoding dictionary + embed_tokens (torch.nn.Embedding): output embedding + no_encoder_attn (bool, optional): whether to attend to encoder outputs + (default: False). + """ + + def __init__( + self, + cfg, + dictionary, + embed_tokens, + no_encoder_attn=False, + ): + super().__init__(dictionary) + + self.dropout = cfg.decoder_dropout + self.share_input_output_embed = cfg.share_decoder_input_output_embed + + input_embed_dim = embed_tokens.embedding_dim + embed_dim = cfg.decoder_embed_dim + self.output_embed_dim = cfg.decoder_embed_dim + + self.layerdrop = cfg.decoder_layerdrop + + padding_idx = embed_tokens.padding_idx + self.max_target_positions = cfg.max_target_positions + + self.embed_tokens = embed_tokens + # self.embed_scale = math.sqrt(embed_dim) # todo: try with input_embed_dim + self.embed_scale = 1.0 if cfg.no_scale_embedding else math.sqrt(embed_dim) + + self.project_in_dim = ( + Linear(input_embed_dim, embed_dim, bias=False) + if embed_dim != input_embed_dim + else None + ) + + self.embed_positions = ( + PositionalEmbedding( + cfg.max_target_positions, + embed_dim, + padding_idx, + learned=cfg.decoder_learned_pos, + ) + if not cfg.no_token_positional_embeddings + else None + ) + + # TODO: update this when transformer gets converted to dataclass configs + transformer_cfg = copy.deepcopy(cfg) + # with open_dict(transformer_cfg): + transformer_cfg.dropout = transformer_cfg.decoder_dropout + transformer_cfg.attention_dropout = ( + transformer_cfg.decoder_attention_dropout + ) + transformer_cfg.activation_dropout = ( + transformer_cfg.decoder_activation_dropout + ) + + self.layers = nn.ModuleList([]) + self.layers.extend( + [ + TransformerDecoderLayer(transformer_cfg, no_encoder_attn) + for _ in range(transformer_cfg.decoder_layers) + ] + ) + + if not self.share_input_output_embed: + self.embed_out = nn.Parameter( + torch.Tensor(len(dictionary), self.output_embed_dim) + ) + nn.init.normal_(self.embed_out, mean=0, std=self.output_embed_dim ** -0.5) + + if transformer_cfg.decoder_normalize_before: + self.layer_norm = LayerNorm(embed_dim) + else: + self.layer_norm = None + + def forward( + self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused + ): + """ + Args: + prev_output_tokens (LongTensor): previous decoder outputs of shape + `(batch, tgt_len)`, for teacher forcing + encoder_out (Tensor, optional): output from the encoder, used for + encoder-side attention + 
incremental_state (dict): dictionary used for storing state during + :ref:`Incremental decoding` + + Returns: + tuple: + - the decoder's output of shape `(batch, tgt_len, vocab)` + - a dictionary with any model-specific outputs + """ + prev_output_tokens = prev_output_tokens.long() + x, extra = self.extract_features( + prev_output_tokens, encoder_out, incremental_state + ) + x = self.output_layer(x) + return x, extra + + def extract_features( + self, prev_output_tokens, encoder_out=None, incremental_state=None, **unused + ): + """ + Similar to *forward* but only return features. + + Returns: + tuple: + - the decoder's features of shape `(batch, tgt_len, embed_dim)` + - a dictionary with any model-specific outputs + """ + + # embed positions + positions = ( + self.embed_positions( + prev_output_tokens, incremental_state=incremental_state + ) + if self.embed_positions is not None + else None + ) + + if incremental_state is not None: + prev_output_tokens = prev_output_tokens[:, -1:] + if positions is not None: + positions = positions[:, -1:] + + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(prev_output_tokens) + + if self.project_in_dim is not None: + x = self.project_in_dim(x) + + if positions is not None: + x += positions + x = F.dropout(x, p=self.dropout, training=self.training) + + # B x T x C -> T x B x C + x = x.transpose(0, 1) + attn = None + + inner_states = [x] + + # decoder layers + for layer in self.layers: + dropout_probability = np.random.random() + if not self.training or (dropout_probability > self.layerdrop): + x, attn, _ = layer( + x, + encoder_out["encoder_out"] if encoder_out is not None else None, + encoder_out["padding_mask"] if encoder_out is not None else None, + incremental_state, + self_attn_mask=self.buffered_future_mask(x) + if incremental_state is None + else None, + ) + inner_states.append(x) + + if self.layer_norm: + x = self.layer_norm(x) + + # T x B x C -> B x T x C + x = x.transpose(0, 1) + + return x, {"attn": attn, "inner_states": inner_states} + + def output_layer(self, features, **kwargs): + """Project features to the vocabulary size.""" + # project back to size of vocabulary + emb_mat = self.embed_tokens.weight if self.share_input_output_embed else self.embed_out + return torch.matmul(features, emb_mat.transpose(0, 1)) + # if self.share_input_output_embed: + # return F.linear(features, self.embed_tokens.weight) + # else: + # return F.linear(features, self.embed_out) + + def max_positions(self): + """Maximum output length supported by the decoder.""" + if self.embed_positions is None: + return self.max_target_positions + return min(self.max_target_positions, self.embed_positions.max_positions) + + def buffered_future_mask(self, tensor): + dim = tensor.size(0) + if ( + not hasattr(self, "_future_mask") + or self._future_mask is None + or self._future_mask.device != tensor.device + or self._future_mask.size(0) < dim + ): + self._future_mask = torch.triu( + utils.fill_with_neg_inf(tensor.new(dim, dim)), 1 + ) + return self._future_mask[:dim, :dim] + + def upgrade_state_dict_named(self, state_dict, name): + return state_dict + diff --git a/src/slam_llm/models/avhubert/hubert.py b/src/slam_llm/models/avhubert/hubert.py new file mode 100644 index 00000000..45386698 --- /dev/null +++ b/src/slam_llm/models/avhubert/hubert.py @@ -0,0 +1,792 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os,sys +import logging +from typing import Dict, List, Optional, Tuple + +import numpy as np + +import torch +import torch.nn as nn +from dataclasses import dataclass, field +from fairseq import utils +from fairseq.data.data_utils import compute_mask_indices +from fairseq.data.dictionary import Dictionary +from fairseq.dataclass import ChoiceEnum, FairseqDataclass +from fairseq.models import BaseFairseqModel, register_model +from fairseq.models.wav2vec.wav2vec2 import ( + ConvFeatureExtractionModel, + TransformerEncoder, +) +from fairseq.modules import GradMultiply, LayerNorm +from copy import deepcopy + +DBG=True if len(sys.argv) == 1 else False + +if DBG: + from hubert_pretraining import ( + AVHubertPretrainingConfig, + AVHubertPretrainingTask, + ) + from resnet import ResEncoder + logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, + ) + from utils import compute_mask_indices + from decoder import TransformerDecoder + +else: + from .hubert_pretraining import ( + AVHubertPretrainingConfig, + AVHubertPretrainingTask, + ) + from .resnet import ResEncoder + from .utils import compute_mask_indices + from .decoder import TransformerDecoder + +from omegaconf import II + +logger = logging.getLogger(__name__) + +EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"]) +MASKING_DISTRIBUTION_CHOICES = ChoiceEnum( + ["static", "uniform", "normal", "poisson"] +) +# LAYER_TYPE_CHOICES = ChoiceEnum(["transformer", "conformer", "trf_adp"]) + + +@dataclass +class AVHubertConfig(FairseqDataclass): + label_rate: int = II("task.label_rate") + input_modality: str = II("task.input_modality") + extractor_mode: EXTRACTOR_MODE_CHOICES = field( + default="default", + metadata={ + "help": "mode for feature extractor. 
default has a single group " + "norm with d groups in the first conv block, whereas layer_norm " + "has layer norms in every block (meant to use with normalize=True)" + }, + ) + encoder_layers: int = field( + default=12, metadata={"help": "num encoder layers in the transformer"} + ) + encoder_embed_dim: int = field( + default=768, metadata={"help": "encoder embedding dimension"} + ) + encoder_ffn_embed_dim: int = field( + default=3072, metadata={"help": "encoder embedding dimension for FFN"} + ) + encoder_attention_heads: int = field( + default=12, metadata={"help": "num encoder attention heads"} + ) + activation_fn: ChoiceEnum(utils.get_available_activation_fns()) = field( + default="gelu", metadata={"help": "activation function to use"} + ) + + # dropouts + dropout: float = field( + default=0.1, + metadata={"help": "dropout probability for the transformer"}, + ) + attention_dropout: float = field( + default=0.1, + metadata={"help": "dropout probability for attention weights"}, + ) + activation_dropout: float = field( + default=0.0, + metadata={"help": "dropout probability after activation in FFN"}, + ) + encoder_layerdrop: float = field( + default=0.0, + metadata={"help": "probability of dropping a tarnsformer layer"}, + ) + dropout_input: float = field( + default=0.0, + metadata={"help": "dropout to apply to the input (after feat extr)"}, + ) + dropout_features: float = field( + default=0.0, + metadata={ + "help": "dropout to apply to the features (after feat extr)" + }, + ) + + final_dim: int = field( + default=0, + metadata={ + "help": "project final representations and targets to this many " + "dimensions. set to encoder_embed_dim is <= 0" + }, + ) + untie_final_proj: bool = field( + default=False, + metadata={"help": "use separate projection for each target"}, + ) + layer_norm_first: bool = field( + default=False, + metadata={"help": "apply layernorm first in the transformer"}, + ) + conv_feature_layers: str = field( + default="[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2", + metadata={ + "help": "string describing convolutional feature extraction " + "layers in form of a python list that contains " + "[(dim, kernel_size, stride), ...]" + }, + ) + conv_bias: bool = field( + default=False, metadata={"help": "include bias in conv encoder"} + ) + logit_temp: float = field( + default=0.1, metadata={"help": "temperature to divide logits by"} + ) + target_glu: bool = field( + default=False, metadata={"help": "adds projection + glu to targets"} + ) + feature_grad_mult: float = field( + default=1.0, + metadata={"help": "multiply feature extractor var grads by this"}, + ) + + # masking + mask_length_audio: int = field(default=10, metadata={"help": "mask length"}) + mask_prob_audio: float = field( + default=0.65, + metadata={"help": "probability of replacing a token with mask"}, + ) + mask_length_image: int = field(default=10, metadata={"help": "mask length"}) + mask_prob_image: float = field( + default=0.65, + metadata={"help": "probability of replacing a token with mask"}, + ) + mask_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", metadata={"help": "how to choose mask length"} + ) + mask_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument " + "(used for more complex distributions), " + "see help in compute_mask_indicesh" + }, + ) + no_mask_overlap: bool = field( + default=False, metadata={"help": "whether to allow masks to overlap"} + ) + mask_min_space: int = field( + default=1, + metadata={ + "help": "min space between spans (if no 
overlap is enabled)" + }, + ) + + # channel masking + mask_channel_length: int = field( + default=10, + metadata={"help": "length of the mask for features (channels)"}, + ) + mask_channel_prob: float = field( + default=0.0, + metadata={"help": "probability of replacing a feature with 0"}, + ) + mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", + metadata={"help": "how to choose mask length for channel masking"}, + ) + mask_channel_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument " + "(used for more complex distributions), " + "see help in compute_mask_indicesh" + }, + ) + no_mask_channel_overlap: bool = field( + default=False, + metadata={"help": "whether to allow channel masks to overlap"}, + ) + mask_channel_min_space: int = field( + default=1, + metadata={ + "help": "min space between spans (if no overlap is enabled)" + }, + ) + + # positional embeddings + conv_pos: int = field( + default=128, + metadata={ + "help": "number of filters for convolutional positional embeddings" + }, + ) + conv_pos_groups: int = field( + default=16, + metadata={ + "help": "number of groups for convolutional positional embedding" + }, + ) + + latent_temp: Tuple[float, float, float] = field( + default=(2, 0.5, 0.999995), + metadata={"help": "legacy (to be removed)"}, + ) + + # loss computation + skip_masked: bool = field( + default=False, + metadata={"help": "skip computing losses over masked frames"}, + ) + skip_nomask: bool = field( + default=False, + metadata={"help": "skip computing losses over unmasked frames"}, + ) + resnet_relu_type: str = field(default='prelu', metadata={"help": 'relu type for resnet'}) + resnet_weights: Optional[str] = field(default=None, metadata={"help": 'resnet weights'}) + sim_type: str = field(default='cosine', metadata={"help": 'similarity type'}) + + sub_encoder_layers: int = field(default=0, metadata={'help': 'number of transformer layers for single modality'}) + audio_feat_dim: int = field(default=-1, metadata={'help': 'audio feature dimension'}) + modality_dropout: float = field(default=0, metadata={'help': 'drop one modality'}) + audio_dropout: float = field(default=0, metadata={'help': 'drop audio feature'}) + modality_fuse: str = field(default='concat', metadata={'help': 'fusing two modalities: add,concat'}) + selection_type : str = field(default='same_other_seq', metadata={'help': 'type of selectig images, same_other_seq: replace masked span with span from another sequence, same_seq: repace masked span with span of the same sequence'}) + masking_type : str = field(default='input', metadata={'help': 'input or feature masking'}) + + decoder_embed_dim: int = field( + default=768, metadata={"help": "decoder embedding dimension"} + ) + decoder_ffn_embed_dim: int = field( + default=3072, metadata={"help": "decoder embedding dimension for FFN"} + ) + decoder_layers: int = field( + default=6, metadata={"help": "num of decoder layers"} + ) + decoder_layerdrop: float = field( + default=0.0, metadata={"help": "decoder layerdrop chance"} + ) + decoder_attention_heads: int = field( + default=4, metadata={"help": "num decoder attention heads"} + ) + decoder_learned_pos: bool = field( + default=False, + metadata={"help": "use learned positional embeddings in the decoder"}, + ) + decoder_normalize_before: bool = field( + default=False, + metadata={"help": "apply layernorm before each decoder block"}, + ) + no_token_positional_embeddings: bool = field( + default=False, + metadata={ + "help": "if set, disables positional 
embeddings " + "(outside self attention)" + }, + ) + decoder_dropout: float = field( + default=0.1, metadata={"help": "dropout probability in the decoder"} + ) + decoder_attention_dropout: float = field( + default=0.1, + metadata={ + "help": "dropout probability for attention weights " + "inside the decoder" + }, + ) + decoder_activation_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability after activation in FFN " + "inside the decoder" + }, + ) + max_target_positions: int = field( + default=2048, metadata={"help": "max target positions"} + ) + share_decoder_input_output_embed: bool = field( + default=False, + metadata={"help": "share decoder input and output embeddings"}, + ) + no_scale_embedding: bool = field(default=True, metadata={'help': 'scale embedding'}) + + # # new fairseq + # required_seq_len_multiple: int = field( + # default=1, + # metadata={ + # "help": "pad the input to encoder such that the sequence length is divisible by multiple" + # }, + # ) + + # layer_type: LAYER_TYPE_CHOICES = field( + # default="transformer", metadata={"help": "layer type in encoder"} + # ) + +class SubModel(nn.Module): + def __init__(self, resnet=None, input_dim=None, cfg=None): + super().__init__() + self.resnet = resnet + self.proj = nn.Linear(input_dim, cfg.encoder_embed_dim) + self.encoder = TransformerEncoder(cfg) if cfg.encoder_layers > 0 else None + + def forward(self, x): #torch.Size([1, 1, 106, 112, 112]) + if self.resnet is not None: + x = self.resnet(x) #torch.Size([1, 512, 106]) #torch.Size([12, 26, 314]) + x = self.proj(x.transpose(1, 2)) #audio是 Linear(in_features=104, out_features=1024, bias=True) 太他妈扯了吧 + if self.encoder is not None: + x = self.encoder(x)[0].transpose(1, 2) + else: # + x = x.transpose(1, 2) + return x #torch.Size([1, 1024, 106]) + +@register_model("av_hubert", dataclass=AVHubertConfig) +class AVHubertModel(BaseFairseqModel): + def __init__( + self, + cfg: AVHubertConfig, + task_cfg: AVHubertPretrainingConfig, + dictionaries: List[Dictionary], + **kwargs + ) -> None: + super().__init__() + logger.info(f"HubertModel Config: {cfg}") + + feature_ds_rate = 1 + self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / task_cfg.sample_rate + sub_cfg = deepcopy(cfg) + sub_cfg.encoder_layers = sub_cfg.sub_encoder_layers + resnet = ResEncoder(relu_type=cfg.resnet_relu_type, weights=cfg.resnet_weights) + self.feature_extractor_audio = SubModel(resnet=None, input_dim=cfg.audio_feat_dim, cfg=sub_cfg) + self.feature_extractor_video = SubModel(resnet=resnet, input_dim=resnet.backend_out, cfg=sub_cfg) + self.modality_dropout, self.audio_dropout = cfg.modality_dropout, cfg.audio_dropout + self.modality_fuse = cfg.modality_fuse + self.encoder_embed_dim = cfg.encoder_embed_dim + if self.modality_fuse == 'concat': + self.embed = cfg.encoder_embed_dim * 2 + elif self.modality_fuse == 'add': + self.embed = cfg.encoder_embed_dim + self.post_extract_proj = ( + nn.Linear(self.embed, cfg.encoder_embed_dim) + if self.embed != cfg.encoder_embed_dim + else None + ) + + self.mask_prob_image, self.mask_prob_audio = cfg.mask_prob_image, cfg.mask_prob_audio + self.mask_selection = cfg.mask_selection + self.mask_other = cfg.mask_other + self.mask_length_image, self.mask_length_audio = cfg.mask_length_image, cfg.mask_length_audio + self.no_mask_overlap = cfg.no_mask_overlap + self.mask_min_space = cfg.mask_min_space + + self.mask_channel_prob = cfg.mask_channel_prob + self.mask_channel_selection = cfg.mask_channel_selection + self.mask_channel_other = 
cfg.mask_channel_other + self.mask_channel_length = cfg.mask_channel_length + self.no_mask_channel_overlap = cfg.no_mask_channel_overlap + self.mask_channel_min_space = cfg.mask_channel_min_space + + self.dropout_input = nn.Dropout(cfg.dropout_input) + self.dropout_features = nn.Dropout(cfg.dropout_features) + + self.feature_grad_mult = cfg.feature_grad_mult + self.logit_temp = cfg.logit_temp + self.skip_masked = cfg.skip_masked + self.skip_nomask = cfg.skip_nomask + self.sim_type = cfg.sim_type + self.selection_type = cfg.selection_type + self.masking_type = cfg.masking_type + + final_dim = ( + cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim + ) + + self.mask_emb = nn.Parameter( + torch.FloatTensor(cfg.audio_feat_dim).uniform_() if self.masking_type == 'input' else torch.FloatTensor(cfg.encoder_embed_dim).uniform_() + ) + + self.encoder = TransformerEncoder(cfg) + self.layer_norm = LayerNorm(self.embed) + + self.target_glu = None + if cfg.target_glu: + self.target_glu = nn.Sequential( + nn.Linear(final_dim, final_dim * 2), nn.GLU() + ) + + self.untie_final_proj = cfg.untie_final_proj + if self.untie_final_proj: + self.final_proj = nn.Linear( + cfg.encoder_embed_dim, final_dim * len(dictionaries) + ) + else: + self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim) + + # modules below are not needed during fine-tuning + if any([d is None for d in dictionaries]): + logger.info( + "cannot find dictionary. assume will be used for fine-tuning" + ) + else: + self.num_classes = [len(d) for d in dictionaries] + self.label_embs_concat = nn.Parameter( + torch.FloatTensor(sum(self.num_classes), final_dim) + ) + nn.init.uniform_(self.label_embs_concat) + + def upgrade_state_dict_named(self, state_dict, name): + """Upgrade a (possibly old) state dict for new versions of fairseq.""" + + super().upgrade_state_dict_named(state_dict, name) + return state_dict + + @classmethod + def build_model(cls, cfg: AVHubertConfig, task: AVHubertPretrainingTask): + """Build a new model instance.""" + + kwargs = {} + model = AVHubertModel(cfg, task.cfg, task.dictionaries, **kwargs) + return model + + def apply_input_mask(self, x, padding_mask, target_list): + B, C, T = x.shape[:3] + is_audio = True if len(x.shape) == 3 else False + if is_audio: + mask_prob, mask_length = self.mask_prob_audio, self.mask_length_audio + else: + mask_prob, mask_length = self.mask_prob_image, self.mask_length_image + if mask_prob > 0: + + mask_indices, starts, ends, batch_indexes = compute_mask_indices( + (B, T), + padding_mask, + mask_prob, + mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices_np = mask_indices + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x = x.transpose(1, 2).contiguous() # [B, T, C, H, W] + if B == 1: + x[mask_indices] = 0 + elif is_audio: + x[mask_indices] = self.mask_emb + elif self.selection_type == 'same_other_seq': + perm = (torch.arange(B) + torch.randint(low=1, high=B, size=(1,))) % B + x_perm = x[perm] + x[mask_indices] = x_perm[mask_indices] + elif self.selection_type == 'same_seq': + batch_indexes_, other_indexes = [], [] + for batch_index, start, end in zip(batch_indexes, starts, ends): + length = end-start + other_start = np.setdiff1d(np.arange(T), np.arange(max(0, start-length), end)) + if len(other_start) > 0: + other_start = np.random.choice(other_start, size=1) + else: + other_start = 0 + other_end = other_start + length + 
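+                    # same_seq selection: the replacement span [other_start, other_end) has the
+                    # same length as the masked span, starts at a frame chosen so the window does
+                    # not overlap the masked region, and is clipped to T-1 so the substitute
+                    # indices never run past the end of the sequence; the masked frames are then
+                    # overwritten with these frames from the same utterance.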
other_indexes.append(np.arange(other_start, other_end).clip(max=T-1)) + batch_indexes_.append(np.zeros([length], dtype=np.int64)+batch_index) + batch_indexes, other_indexes = np.concatenate(batch_indexes_), np.concatenate(other_indexes) + x[mask_indices] = x[batch_indexes, other_indexes] + + x = x.transpose(1, 2).contiguous() + else: + mask_indices = None + + if self.mask_channel_prob > 0: + logger.info(f"No mask channel prob for input masking") + return x, mask_indices + + def apply_feature_mask(self, x, padding_mask, target_list): + B, T, C = x.shape + assert self.mask_prob_audio == self.mask_prob_image and self.mask_length_audio == self.mask_length_image, f"masking prob/length for image/audio be same for feature masking" + mask_prob, mask_length = self.mask_prob_audio, self.mask_length_image + if mask_prob > 0: + mask_indices, _, _, _ = compute_mask_indices( + (B, T), + padding_mask, + mask_prob, + mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x[mask_indices] = self.mask_emb + else: + mask_indices = None + + if self.mask_channel_prob > 0: + mask_channel_indices, _, _, _ = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x[mask_channel_indices] = 0 + + return x, mask_indices + + def forward_features(self, source: torch.Tensor, modality: str) -> torch.Tensor: + extractor = eval(f"self.feature_extractor_{modality}") + if self.feature_grad_mult > 0: + features = extractor(source) + if self.feature_grad_mult != 1.0: + features = GradMultiply.apply(features, self.feature_grad_mult) + else: + with torch.no_grad(): + features = extractor(source) + return features + + def forward_targets( + self, features: torch.Tensor, mask_indices: torch.Tensor, target_list: List[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Trim features to ensure labels exist and then get aligned labels + feat_tsz = features.size(2) + targ_tsz = min([t.size(1) for t in target_list]) + if self.feat2tar_ratio * feat_tsz > targ_tsz: + feat_tsz = int(targ_tsz / self.feat2tar_ratio) + features = features[..., :feat_tsz] + if mask_indices is not None: + mask_indices = mask_indices[..., :feat_tsz] + target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio + target_list = [t[:, target_inds.long()] for t in target_list] + return features, mask_indices, target_list + + def forward_padding_mask( + self, features: torch.Tensor, padding_mask: torch.Tensor, + ) -> torch.Tensor: + extra = padding_mask.size(1) % features.size(1) + if extra > 0: + padding_mask = padding_mask[:, :-extra] + padding_mask = padding_mask.view( + padding_mask.size(0), features.size(1), -1 + ) + padding_mask = padding_mask.all(-1) + return padding_mask + + def compute_logits(self, feats, emb_mat): + # feats: [B, T, F], emb_mat: [V, F] + if self.sim_type == 'dot': + logits = torch.matmul(feats, emb_mat.transpose(0, 1)) + elif self.sim_type == 'cosine': + batch_size, timesteps, emb_dim = feats.size() + feats_ = feats.view(-1, emb_dim) + nom = (feats_.unsqueeze(dim=1) * emb_mat.unsqueeze(dim=0)).sum(dim=-1) # [B*T, V] + denom = (feats_**2).sum(dim=-1).sqrt().unsqueeze(dim=1) 
* (emb_mat**2).sum(dim=-1).sqrt().unsqueeze(dim=0) # [B*T, V] + logits = (nom/denom.clamp(min=1e-6)).view(batch_size, timesteps, -1) + else: + raise NotImplementedError + logits = logits / self.logit_temp + return logits + + def forward( + self, + source: torch.Tensor, + target_list: Optional[List[torch.Tensor]] = None, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = True, + features_only: bool = False, + output_layer: Optional[int] = None + ) -> Dict[str, torch.Tensor]: + """output layer is 1-based""" + src_audio, src_video = source['audio'], source['video'] + if mask and self.masking_type == 'input': + src_video, mask_indices_video = self.apply_input_mask(src_video, padding_mask, target_list) + src_audio, mask_indices_audio = self.apply_input_mask(src_audio, padding_mask, target_list) + mask_indices = torch.logical_or(mask_indices_audio, mask_indices_video) + else: + src_audio, src_video, mask_indices = src_audio, src_video, None + + features_audio = self.forward_features(src_audio, modality='audio') # features: [B, F, T] + features_video = self.forward_features(src_video, modality='video') + modality_drop_prob, audio_drop_prob = np.random.random(), np.random.random() + if self.training: + if modality_drop_prob < self.modality_dropout: + if audio_drop_prob < self.audio_dropout: + features_audio = 0 * features_audio + else: + features_video = 0 * features_video + if self.modality_fuse == 'concat': + features = torch.cat([features_audio, features_video], dim=1) + elif self.modality_fuse == 'add': + features = features_audio + features_video + if target_list is not None: + features, mask_indices, target_list = self.forward_targets(features, mask_indices, target_list) + + features_pen = features.float().pow(2).mean() + + features = features.transpose(1, 2) + features = self.layer_norm(features) + + if padding_mask is not None: + padding_mask = self.forward_padding_mask(features, padding_mask) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) + + features = self.dropout_input(features) + if self.masking_type == 'feature' and mask: + x, mask_indices = self.apply_feature_mask(features, padding_mask, target_list) + else: + x = features + + # feature: (B, T, D), float + # target: (B, T), long + # x: (B, T, D), float + # padding_mask: (B, T), bool + # mask_indices: (B, T), bool + x, _ = self.encoder( + x, + padding_mask=padding_mask, + layer=None if output_layer is None else output_layer - 1 + ) + + if features_only: + return {"x": x, "padding_mask": padding_mask, "features": features} + + label_embs_list = self.label_embs_concat.split(self.num_classes, 0) + proj_x = self.final_proj(x) + if self.untie_final_proj: + proj_x_list = proj_x.chunk(len(self.num_classes), dim=-1) + else: + proj_x_list = [proj_x for _ in self.num_classes] + logit_list = [self.compute_logits(proj, emb).view(-1, num_class) for proj, emb, num_class in zip(proj_x_list, label_embs_list, self.num_classes)] # [[B*T, V]] + mask, unmask = torch.logical_and(mask_indices, ~padding_mask).view(-1), torch.logical_and(~mask_indices, ~padding_mask).view(-1) # [B*T] + logit_m_list, logit_u_list = [logit[mask] for logit in logit_list], [logit[unmask] for logit in logit_list] + target_m_list, target_u_list = [target.view(-1)[mask].long() for target in target_list], [target.view(-1)[unmask].long() for target in target_list] + result = { + "logit_m_list": logit_m_list, + "logit_u_list": logit_u_list, + "target_m_list": target_m_list, + "target_u_list": target_u_list, + "padding_mask": 
padding_mask, + "features_pen": features_pen, + } + return result + + def extract_features( + self, + source: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + mask: bool = False, + ret_conv: bool = False, + output_layer: Optional[int] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + res = self.forward( + source, + padding_mask=padding_mask, + mask=mask, + features_only=True, + output_layer=output_layer, + ) + feature = res["features"] if ret_conv else res["x"] + return feature, res["padding_mask"] + + def extract_finetune(self, source, padding_mask=None, mask=False, ret_conv=False, output_layer=None): + src_audio, src_video = source['audio'], source['video'] #torch.Size([1, 1, 106, 112, 112]) + if mask and self.masking_type == 'input': + src_video, mask_indices_video = self.apply_input_mask(src_video, padding_mask, target_list=None) + src_audio, mask_indices_audio = self.apply_input_mask(src_audio, padding_mask, target_list=None) + mask_indices = torch.logical_or(mask_indices_audio, mask_indices_video) # mask_indices not used in fine-tuning + else: # + src_audio, src_video, mask_indices = src_audio, src_video, None + + if src_audio is not None and src_video is None: + features_audio = self.forward_features(src_audio, modality='audio') # features: [B, F, T] + features_video = features_audio.new_zeros(features_audio.size(0), self.encoder_embed_dim, features_audio.size(-1)) + elif src_audio is None and src_video is not None: + features_video = self.forward_features(src_video, modality='video') + features_audio = features_video.new_zeros(features_video.size(0), self.encoder_embed_dim, features_video.size(-1)) #全0! + elif src_audio is not None and src_video is not None: + features_video = self.forward_features(src_video, modality='video') #torch.Size([1, 1024, 106]) #scr torch.Size([12, 1, 314, 88, 88]) + features_audio = self.forward_features(src_audio, modality='audio') # features: [B, F, T] #torch.Size([12, 26, 314]) + + if self.modality_fuse == 'concat': # + features = torch.cat([features_audio, features_video], dim=1) #torch.Size([1, 2048, 106]) + elif self.modality_fuse == 'add': + features = features_audio + features_video + features_pen = features.float().pow(2).mean() + + features = features.transpose(1, 2) + features = self.layer_norm(features) + unmasked_features = features.clone() + + if padding_mask is not None: #features:torch.Size([1, 106, 2048]) + padding_mask = self.forward_padding_mask(features, padding_mask) #torch.Size([4, 154]) + + if self.post_extract_proj is not None: + features = self.post_extract_proj(features) #torch.Size([1, 106, 1024]) + + features = self.dropout_input(features) + unmasked_features = self.dropout_features(unmasked_features) + x = features + mask_indices = None + + # feature: (B, T, D), float + # target: (B, T), long + # x: (B, T, D), float + # padding_mask: (B, T), bool + # mask_indices: (B, T), bool + x, _ = self.encoder( + x, + padding_mask=padding_mask, + layer=None if output_layer is None else output_layer - 1 + ) + + return x, padding_mask #torch.Size([1, 106, 1024]), None + + + def get_extra_losses(self, net_output): + extra_losses = [] + names = [] + if "features_pen" in net_output: + extra_losses.append(net_output["features_pen"]) + names.append("features_pen") + + return extra_losses, names + + def remove_pretraining_modules(self): + self.target_glu = None + self.final_proj = None + + def get_logits(self, net_output, is_masked=True): + raise NotImplementedError + + def get_targets(self, net_output, is_masked=True): + raise 
NotImplementedError + + def compute_nce(self, x, pos, negs): + neg_is_pos = (pos == negs).all(-1) + pos = pos.unsqueeze(0) + targets = torch.cat([pos, negs], dim=0) + + logits = torch.cosine_similarity( + x.float(), targets.float(), dim=-1 + ).type_as(x) + logits /= self.logit_temp + if neg_is_pos.any(): + logits[1:][neg_is_pos] = float("-inf") + logits = logits.transpose(0, 1) # (num_x, num_cls+1) + return logits diff --git a/src/slam_llm/models/avhubert/hubert_asr.py b/src/slam_llm/models/avhubert/hubert_asr.py new file mode 100644 index 00000000..cebc69f0 --- /dev/null +++ b/src/slam_llm/models/avhubert/hubert_asr.py @@ -0,0 +1,523 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import sys,logging +import contextlib +import tempfile +from argparse import Namespace +from typing import Any, Optional + +import torch +import torch.nn as nn +from dataclasses import dataclass, field +from fairseq import checkpoint_utils, tasks, utils +from fairseq.dataclass import FairseqDataclass +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.models import BaseFairseqModel, FairseqEncoder, FairseqEncoderDecoderModel, register_model +from fairseq.models.hubert.hubert import MASKING_DISTRIBUTION_CHOICES +from fairseq.tasks import FairseqTask +from omegaconf import II, MISSING + +DBG=True if len(sys.argv) == 1 else False + +if DBG: + from hubert import AVHubertModel + from decoder import TransformerDecoder +else: + from .hubert import AVHubertModel + from .decoder import TransformerDecoder + +logger = logging.getLogger(__name__) + + +@dataclass +class AVHubertAsrConfig(FairseqDataclass): + w2v_path: str = field( + default=MISSING, metadata={"help": "path to hubert model"} + ) + no_pretrained_weights: bool = field( + default=False, + metadata={"help": "if true, does not load pretrained weights"}, + ) + dropout_input: float = field( + default=0.0, + metadata={"help": "dropout to apply to the input (after feat extr)"}, + ) + final_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout after transformer and before final projection" + }, + ) + dropout: float = field( + default=0.0, + metadata={"help": "dropout probability inside hubert model"}, + ) + attention_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability for attention weights " + "inside hubert model" + }, + ) + activation_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability after activation in FFN " + "inside hubert model" + }, + ) + + # masking + apply_mask: bool = field( + default=False, metadata={"help": "apply masking during fine-tuning"} + ) + mask_length: int = field( + default=10, metadata={"help": "repeat the mask indices multiple times"} + ) + mask_prob: float = field( + default=0.5, + metadata={ + "help": "probability of replacing a token with mask " + "(normalized by length)" + }, + ) + mask_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", metadata={"help": "how to choose masks"} + ) + mask_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument " + "(used for more complex distributions), " + "see help in compute_mask_indices" + }, + ) + no_mask_overlap: bool = field( + default=False, metadata={"help": "whether to allow masks to overlap"} + ) + + # channel masking + mask_channel_length: int = field( + default=10, + metadata={"help": 
"length of the mask for features (channels)"}, + ) + mask_channel_prob: float = field( + default=0.0, + metadata={"help": "probability of replacing a feature with 0"}, + ) + mask_channel_selection: MASKING_DISTRIBUTION_CHOICES = field( + default="static", + metadata={"help": "how to choose mask length for channel masking"}, + ) + mask_channel_other: float = field( + default=0, + metadata={ + "help": "secondary mask argument " + "(used for more complex distributions), " + "see help in compute_mask_indices" + }, + ) + no_mask_channel_overlap: bool = field( + default=False, + metadata={"help": "whether to allow channel masks to overlap"}, + ) + freeze_finetune_updates: int = field( + default=0, + metadata={"help": "dont finetune hubert for this many updates"}, + ) + feature_grad_mult: float = field( + default=0.0, + metadata={"help": "reset feature grad mult in hubert to this"}, + ) + layerdrop: float = field( + default=0.0, + metadata={"help": "probability of dropping a layer in hubert"}, + ) + normalize: bool = II("task.normalize") + data: str = II("task.data") + + # this holds the loaded hubert args + w2v_args: Any = None + + +@dataclass +class AVHubertCtcConfig(AVHubertAsrConfig): + pass + + +@register_model("av_hubert_ctc", dataclass=AVHubertCtcConfig) +class AVHubertCtc(BaseFairseqModel): + def __init__(self, cfg: AVHubertCtcConfig, w2v_encoder: BaseFairseqModel): + super().__init__() + self.cfg = cfg + self.w2v_encoder = w2v_encoder + + def upgrade_state_dict_named(self, state_dict, name): + super().upgrade_state_dict_named(state_dict, name) + return state_dict + + @classmethod + def build_model(cls, cfg: AVHubertCtcConfig, task: FairseqTask): + """Build a new model instance.""" + w2v_encoder = HubertEncoder(cfg, task.target_dictionary) + return cls(cfg, w2v_encoder) + + def get_normalized_probs(self, net_output, log_probs): + """Get normalized probabilities (or log probs) from a net's output.""" + + logits = net_output["encoder_out"] + if log_probs: + return utils.log_softmax(logits.float(), dim=-1) + else: + return utils.softmax(logits.float(), dim=-1) + + def get_logits(self, net_output): + logits = net_output["encoder_out"] + padding = net_output["encoder_padding_mask"] + if padding is not None and padding.any(): + padding = padding.T + logits[padding][..., 0] = 0 + logits[padding][..., 1:] = float("-inf") + + return logits + + def forward(self, **kwargs): + x = self.w2v_encoder(**kwargs) + return x + + +@dataclass +class AVHubertSeq2SeqConfig(AVHubertAsrConfig): + decoder_embed_dim: int = field( + default=768, metadata={"help": "decoder embedding dimension"} + ) + decoder_ffn_embed_dim: int = field( + default=3072, metadata={"help": "decoder embedding dimension for FFN"} + ) + decoder_layers: int = field( + default=6, metadata={"help": "num of decoder layers"} + ) + decoder_layerdrop: float = field( + default=0.0, metadata={"help": "decoder layerdrop chance"} + ) + decoder_attention_heads: int = field( + default=4, metadata={"help": "num decoder attention heads"} + ) + decoder_learned_pos: bool = field( + default=False, + metadata={"help": "use learned positional embeddings in the decoder"}, + ) + decoder_normalize_before: bool = field( + default=False, + metadata={"help": "apply layernorm before each decoder block"}, + ) + no_token_positional_embeddings: bool = field( + default=False, + metadata={ + "help": "if set, disables positional embeddings " + "(outside self attention)" + }, + ) + decoder_dropout: float = field( + default=0.0, metadata={"help": "dropout probability in the 
decoder"} + ) + decoder_attention_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability for attention weights " + "inside the decoder" + }, + ) + decoder_activation_dropout: float = field( + default=0.0, + metadata={ + "help": "dropout probability after activation in FFN " + "inside the decoder" + }, + ) + max_target_positions: int = field( + default=2048, metadata={"help": "max target positions"} + ) + share_decoder_input_output_embed: bool = field( + default=False, + metadata={"help": "share decoder input and output embeddings"}, + ) + no_scale_embedding: bool = field(default=True, metadata={'help': 'scale embedding'}) + +class HubertEncoder(FairseqEncoder): + def __init__(self, cfg: AVHubertAsrConfig, tgt_dict=None): + self.apply_mask = cfg.apply_mask + + arg_overrides = { + "dropout": cfg.dropout, + "activation_dropout": cfg.activation_dropout, + "dropout_input": cfg.dropout_input, + "attention_dropout": cfg.attention_dropout, + "mask_length": cfg.mask_length, + "mask_prob": cfg.mask_prob, + "mask_selection": cfg.mask_selection, + "mask_other": cfg.mask_other, + "no_mask_overlap": cfg.no_mask_overlap, + "mask_channel_length": cfg.mask_channel_length, + "mask_channel_prob": cfg.mask_channel_prob, + "mask_channel_selection": cfg.mask_channel_selection, + "mask_channel_other": cfg.mask_channel_other, + "no_mask_channel_overlap": cfg.no_mask_channel_overlap, + "encoder_layerdrop": cfg.layerdrop, + "feature_grad_mult": cfg.feature_grad_mult, + } + + if cfg.w2v_args is None: + state = checkpoint_utils.load_checkpoint_to_cpu( + cfg.w2v_path, arg_overrides + ) + w2v_args = state.get("cfg", None) + if w2v_args is None: + w2v_args = convert_namespace_to_omegaconf(state["args"]) + cfg.w2v_args = w2v_args + else: + state = None + w2v_args = cfg.w2v_args + if isinstance(w2v_args, Namespace): + cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf( + w2v_args + ) + + assert cfg.normalize == w2v_args.task.normalize, ( + "Fine-tuning works best when data normalization is the same. 
" + "Please check that --normalize is set or unset for " + "both pre-training and here" + ) + + w2v_args.task.data = cfg.data + + task = tasks.setup_task(w2v_args.task) + model = task.build_model(w2v_args.model) + + if state is not None and not cfg.no_pretrained_weights: + # set strict=False because we omit some modules + model.load_state_dict(state["model"], strict=False) + + model.remove_pretraining_modules() + + super().__init__(task.source_dictionary) + + d = model.encoder.embedding_dim + + self.w2v_model = model + + self.final_dropout = nn.Dropout(cfg.final_dropout) + self.freeze_finetune_updates = cfg.freeze_finetune_updates + self.num_updates = 0 + + if tgt_dict is not None: + self.proj = Linear(d, len(tgt_dict)) + elif getattr(cfg, "decoder_embed_dim", d) != d: + self.proj = Linear(d, cfg.decoder_embed_dim) + else: + self.proj = None + + def set_num_updates(self, num_updates): + """Set the number of parameters updates.""" + super().set_num_updates(num_updates) + self.num_updates = num_updates + + def forward(self, source, padding_mask, tbc=True, **kwargs): + + w2v_args = { + "source": source, + "padding_mask": padding_mask, + "mask": self.apply_mask and self.training, + } + ft = self.freeze_finetune_updates <= self.num_updates + + with torch.no_grad() if not ft else contextlib.ExitStack(): + x, padding_mask = self.w2v_model.extract_finetune(**w2v_args) + + if tbc: + # B x T x C -> T x B x C + x = x.transpose(0, 1) + + x = self.final_dropout(x) + + if self.proj: + x = self.proj(x) + + return { + "encoder_out": x, # T x B x C + "encoder_padding_mask": padding_mask, # B x T + "padding_mask": padding_mask, + } + + def reorder_encoder_out(self, encoder_out, new_order): + if encoder_out["encoder_out"] is not None: + encoder_out["encoder_out"] = encoder_out[ + "encoder_out" + ].index_select(1, new_order) + if encoder_out["encoder_padding_mask"] is not None: + encoder_out["encoder_padding_mask"] = encoder_out[ + "encoder_padding_mask" + ].index_select(0, new_order) + return encoder_out + + def max_positions(self): + """Maximum input length supported by the encoder.""" + return None + + def upgrade_state_dict_named(self, state_dict, name): + return state_dict + + +class HubertEncoderWrapper(FairseqEncoder): + def __init__(self, w2v_model): + super().__init__(None) + self.w2v_model = w2v_model + + def forward(self, source, padding_mask, **kwargs): + w2v_args = { + "source": source, + "padding_mask": padding_mask, + } + + x, padding_mask = self.w2v_model.extract_finetune(**w2v_args) + # B x T x C -> T x B x C + x = x.transpose(0, 1) #torch.Size([106, 1, 1024]) + + return { + "encoder_out": x, # T x B x C + "encoder_padding_mask": padding_mask, # B x T + "padding_mask": padding_mask + } + + def reorder_encoder_out(self, encoder_out, new_order): + if encoder_out["encoder_out"] is not None: + encoder_out["encoder_out"] = encoder_out[ + "encoder_out" + ].index_select(1, new_order) + if encoder_out["encoder_padding_mask"] is not None: + encoder_out["encoder_padding_mask"] = encoder_out[ + "encoder_padding_mask" + ].index_select(0, new_order) + if encoder_out["padding_mask"] is not None: + encoder_out["padding_mask"] = encoder_out[ + "padding_mask" + ].index_select(0, new_order) + return encoder_out + +@register_model("av_hubert_seq2seq", dataclass=AVHubertSeq2SeqConfig) +class AVHubertSeq2Seq(FairseqEncoderDecoderModel): + def __init__(self, encoder, decoder, tgt_dict, cfg): + super().__init__(encoder, decoder) + self.cfg = cfg + self.freeze_finetune_updates = cfg.freeze_finetune_updates + + 
@classmethod + def build_model(cls, cfg, task): + """Build a new model instance.""" + + arg_overrides = { + "dropout": cfg.dropout, + "activation_dropout": cfg.activation_dropout, + "dropout_input": cfg.dropout_input, + "attention_dropout": cfg.attention_dropout, + "mask_length": cfg.mask_length, + "mask_prob": cfg.mask_prob, + "mask_selection": cfg.mask_selection, + "mask_other": cfg.mask_other, + "no_mask_overlap": cfg.no_mask_overlap, + "mask_channel_length": cfg.mask_channel_length, + "mask_channel_prob": cfg.mask_channel_prob, + "mask_channel_selection": cfg.mask_channel_selection, + "mask_channel_other": cfg.mask_channel_other, + "no_mask_channel_overlap": cfg.no_mask_channel_overlap, + "encoder_layerdrop": cfg.layerdrop, + "feature_grad_mult": cfg.feature_grad_mult, + } + + if cfg.w2v_args is None: + state = checkpoint_utils.load_checkpoint_to_cpu( + cfg.w2v_path, arg_overrides + ) + w2v_args = state.get("cfg", None) + if w2v_args is None: + w2v_args = convert_namespace_to_omegaconf(state["args"]) + cfg.w2v_args = w2v_args + else: + state = None + w2v_args = cfg.w2v_args + if isinstance(w2v_args, Namespace): + cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf( + w2v_args + ) + + assert cfg.normalize == w2v_args.task.normalize, ( + "Fine-tuning works best when data normalization is the same. " + "Please check that --normalize is set or unset for " + "both pre-training and here" + ) + + w2v_args.task.data = cfg.data + + task_pretrain = tasks.setup_task(w2v_args.task) + if state is not None: + task_pretrain.load_state_dict(state['task_state']) + + encoder_ = task_pretrain.build_model(w2v_args.model) + + encoder = HubertEncoderWrapper(encoder_) + if state is not None and not cfg.no_pretrained_weights: + # set strict=False because we omit some modules + del state['model']['mask_emb'] + encoder.w2v_model.load_state_dict(state["model"], strict=False) + + encoder.w2v_model.remove_pretraining_modules() + + src_dict, tgt_dict = task.source_dictionary, task.target_dictionary + + def build_embedding(dictionary, embed_dim): + num_embeddings = len(dictionary) + padding_idx = dictionary.pad() + emb = Embedding(num_embeddings, embed_dim, padding_idx=padding_idx) + return emb + + decoder_embed_tokens = build_embedding(tgt_dict, cfg.decoder_embed_dim) + decoder = TransformerDecoder(cfg, tgt_dict, decoder_embed_tokens) + + return AVHubertSeq2Seq(encoder, decoder, tgt_dict, cfg) + + + def forward(self, **kwargs): + # ft = self.freeze_finetune_updates <= self.num_updates + # with torch.no_grad() if not ft else contextlib.ExitStack(): + # output = self.encoder(**kwargs) + with torch.no_grad(): + output = self.encoder(**kwargs) #encoder_out,encoder_padding_mask,padding_mask + # decoder_out = self.decoder(prev_output_tokens=kwargs['prev_output_tokens'], encoder_out=output) + return output + + def upgrade_state_dict_named(self, state_dict, name): + super().upgrade_state_dict_named(state_dict, name) + return state_dict + + def set_num_updates(self, num_updates): + """Set the number of parameters updates.""" + super().set_num_updates(num_updates) + self.num_updates = num_updates + +def Embedding(num_embeddings, embedding_dim, padding_idx): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) + nn.init.constant_(m.weight[padding_idx], 0) + return m + + +def Linear(in_features, out_features, bias=True): + m = nn.Linear(in_features, out_features, bias) + nn.init.xavier_uniform_(m.weight) + if bias: + 
nn.init.constant_(m.bias, 0.0) + return m diff --git a/src/slam_llm/models/avhubert/hubert_criterion.py b/src/slam_llm/models/avhubert/hubert_criterion.py new file mode 100644 index 00000000..51b1881b --- /dev/null +++ b/src/slam_llm/models/avhubert/hubert_criterion.py @@ -0,0 +1,169 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +import re +from dataclasses import dataclass, field +from typing import List, Optional + +import torch +import torch.nn.functional as F +from fairseq import metrics, utils +from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.dataclass import FairseqDataclass + + +@dataclass +class AVHubertCriterionConfig(FairseqDataclass): + pred_masked_weight: float = field( + default=1.0, + metadata={"help": "weight for predictive loss for masked frames"}, + ) + pred_nomask_weight: float = field( + default=0.0, + metadata={"help": "weight for predictive loss for unmasked frames"}, + ) + loss_weights: Optional[List[float]] = field( + default=None, + metadata={"help": "weights for additional loss terms (not first one)"}, + ) + log_keys: List[str] = field( + default_factory=lambda: [], + metadata={"help": "output keys to log"}, + ) + + +@register_criterion("av_hubert", dataclass=AVHubertCriterionConfig) +class AVHubertCriterion(FairseqCriterion): + def __init__(self, task, pred_masked_weight, pred_nomask_weight, loss_weights=None, log_keys=None): + super().__init__(task) + self.pred_masked_weight = pred_masked_weight + self.pred_nomask_weight = pred_nomask_weight + self.loss_weights = loss_weights + self.log_keys = [] if log_keys is None else log_keys + + def forward(self, model, sample, reduce=True, log_pred=False): + """Compute the loss for the given sample. + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(target_list=sample["target_list"], **sample["net_input"]) + loss = 0. 
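+        # The main loss accumulated below is a weighted sum of frame-level cross-entropy terms:
+        # logits over masked frames (scaled by pred_masked_weight) and logits over unmasked
+        # frames (scaled by pred_nomask_weight), one term per target dictionary; sample_size
+        # counts the target frames that actually contribute. Extra regularizers such as
+        # features_pen are added afterwards via loss_weights.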
+ sample_size = 0 + logging_output = {} + reduction = "sum" if reduce else "none" + + loss_m_list = [] + logp_m_list, targ_m_list = net_output['logit_m_list'], net_output['target_m_list'] + for i, (logp_m, targ_m) in enumerate(zip(logp_m_list, targ_m_list)): + loss_m = F.cross_entropy(logp_m, targ_m, reduction=reduction) + loss_m_list.append(loss_m) + logging_output[f"loss_m_{i}"] = loss_m.detach().item() + if self.pred_masked_weight > 0: + loss += self.pred_masked_weight * sum(loss_m_list) + sample_size += targ_m_list[0].numel() + + loss_u_list = [] + logp_u_list, targ_u_list = net_output['logit_u_list'], net_output['target_u_list'] + for i, (logp_u, targ_u) in enumerate(zip(logp_u_list, targ_u_list)): + loss_u = F.cross_entropy(logp_u, targ_u, reduction=reduction) + loss_u_list.append(loss_u) + logging_output[f"loss_u_{i}"] = loss_u.detach().item() + if self.pred_nomask_weight > 0: + loss += self.pred_nomask_weight * sum(loss_u_list) + sample_size += targ_u_list[0].numel() + + if self.loss_weights is not None: + assert hasattr(model, "get_extra_losses") + extra_losses, names = model.get_extra_losses(net_output) + if torch.is_tensor(extra_losses): + extra_losses = [extra_losses] + names = [names] + if len(self.loss_weights) == 1 and len(extra_losses) != 1: + self.loss_weights = [self.loss_weights[0]] * len(extra_losses) + assert len(extra_losses) == len(self.loss_weights), f"{len(extra_losses)}, {len(self.loss_weights)}" + for p, n, coef in zip(extra_losses, names, self.loss_weights): + if coef != 0 and p is not None: + p = coef * p.float() * sample_size + loss += p + logging_output[f"loss_{n}"] = p.item() + + logging_output = { + "loss": loss.item() if reduce else loss, + "ntokens": sample_size, + "nsentences": sample["id"].numel(), + "sample_size": sample_size, + **logging_output, + } + + for lk in self.log_keys: + if lk in net_output: + logging_output[lk] = float((net_output[lk])) + + with torch.no_grad(): + for i, logp_m in enumerate(logp_m_list): + # corr_m, count_m = compute_correct(logp_m) + if logp_m.numel() == 0: + corr_m, count_m = 0, 0 + else: + corr_m, count_m = (logp_m.argmax(dim=-1)==targ_m_list[i]).sum().item(), len(targ_m_list[i]) + logging_output[f"correct_m_{i}"] = corr_m + logging_output[f"count_m_{i}"] = count_m + + for i, logp_u in enumerate(logp_u_list): + if logp_u.numel() == 0: + corr_u, count_u = 0, 0 + else: + corr_u, count_u = (logp_u.argmax(dim=-1)==targ_u_list[i]).sum().item(), len(targ_u_list[i]) + logging_output[f"correct_u_{i}"] = corr_u + logging_output[f"count_u_{i}"] = count_u + + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs) -> None: + """Aggregate logging outputs from data parallel training (copied from normal cross entropy).""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + ntokens = sum(log.get("ntokens", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + + metrics.log_scalar("loss", loss_sum / sample_size / math.log(2), sample_size, round=3) + if sample_size != ntokens: + metrics.log_scalar("nll_loss", loss_sum / ntokens / math.log(2), ntokens, round=3) + metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["nll_loss"].avg)) + else: + metrics.log_derived("ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)) + + counts = {} + for lk in logging_outputs[0].keys(): + if lk.startswith("count_"): + val = sum(log[lk] for log in logging_outputs) + metrics.log_scalar(lk, val) + counts[lk] = val + + for lk in 
logging_outputs[0].keys(): + if lk.startswith("loss_"): + val = sum(log[lk] for log in logging_outputs) + metrics.log_scalar(lk, val / sample_size / math.log(2), round=3) + elif lk.startswith("correct_"): + val = sum(log[lk] for log in logging_outputs) + metrics.log_scalar(lk, val / counts[re.sub("correct", "count", lk)]) + + @staticmethod + def aggregate_logging_outputs(logging_outputs): + """Aggregate logging outputs from data parallel training.""" + raise NotImplementedError() + + @staticmethod + def logging_outputs_can_be_summed() -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return False diff --git a/src/slam_llm/models/avhubert/hubert_dataset.py b/src/slam_llm/models/avhubert/hubert_dataset.py new file mode 100644 index 00000000..e80895f1 --- /dev/null +++ b/src/slam_llm/models/avhubert/hubert_dataset.py @@ -0,0 +1,529 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import itertools +import logging +import os +import sys +import time +from typing import Any, List, Optional, Union + +import numpy as np + +import torch +import torch.nn.functional as F +from fairseq.data import data_utils +from fairseq.data.fairseq_dataset import FairseqDataset +from python_speech_features import logfbank +from scipy.io import wavfile + +DBG=True if len(sys.argv) == 1 else False + +if DBG: + import utils as custom_utils + logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "DEBUG").upper(), + stream=sys.stdout, + ) +else: + from . 
import utils as custom_utils
+
+logger = logging.getLogger(__name__)
+
+
+def load_audio_visual(manifest_path, max_keep, min_keep, frame_rate, label_paths, label_rates, tol=0.1):
+    def is_audio_label_aligned(audio_dur, label_durs):
+        return all([abs(audio_dur - label_dur) < tol for label_dur in label_durs])
+
+    n_long, n_short, n_unaligned = 0, 0, 0
+    names, inds, sizes = [], [], []
+    dur_from_label_list = []
+    is_seq_label = any([x == -1 for x in label_rates])
+    for label_path, label_rate in zip(label_paths, label_rates):
+        label_lengths = [len(line.rstrip().split()) / label_rate for line in open(label_path).readlines()]
+        dur_from_label_list.append(label_lengths)
+    dur_from_label_list = list(zip(*dur_from_label_list))
+
+    with open(manifest_path) as f:
+        root = f.readline().strip()
+        for ind, line in enumerate(f):
+            items = line.strip().split("\t")
+            sz = int(items[-2])
+            if min_keep is not None and sz < min_keep:
+                n_short += 1
+            elif max_keep is not None and sz > max_keep:
+                n_long += 1
+            elif (not is_seq_label) and (not is_audio_label_aligned(sz/frame_rate, dur_from_label_list[ind])):
+                n_unaligned += 1
+            else:
+                video_path = items[1]
+                audio_path = items[2]
+                audio_id = items[0]
+                names.append((video_path, audio_path+':'+audio_id))
+                inds.append(ind)
+                sizes.append(sz)
+        tot = ind + 1
+    logger.info(
+        (
+            f"max_keep={max_keep}, min_keep={min_keep}, "
+            f"loaded {len(names)}, skipped {n_short} short and {n_long} long and {n_unaligned} unaligned, "
+            f"longest-loaded={max(sizes)}, shortest-loaded={min(sizes)}"
+        )
+    )
+    return root, names, inds, tot, sizes
+
+def load_label(label_path, inds, tot):
+    with open(label_path) as f:
+        labels = [line.rstrip() for line in f]
+        assert (
+            len(labels) == tot
+        ), f"number of labels does not match ({len(labels)} != {tot})"
+        labels = [labels[i] for i in inds]
+    return labels
+
+
+def load_label_offset(label_path, inds, tot):
+    with open(label_path) as f:
+        code_lengths = [len(line.encode("utf-8")) for line in f]
+        assert (
+            len(code_lengths) == tot
+        ), f"number of labels does not match ({len(code_lengths)} != {tot})"
+        offsets = list(itertools.accumulate([0] + code_lengths))
+        offsets = [(offsets[i], offsets[i + 1]) for i in inds]
+    return offsets
+
+
+def verify_label_lengths(
+    audio_sizes,
+    audio_rate,
+    label_path,
+    label_rate,
+    inds,
+    tot,
+    tol=0.1,  # tolerance in seconds
+):
+    if label_rate < 0:
+        logger.info(f"{label_path} is sequence label. skipped")
+        return
+
+    with open(label_path) as f:
+        lengths = [len(line.rstrip().split()) for line in f]
+        assert len(lengths) == tot
+        lengths = [lengths[i] for i in inds]
+    num_invalid = 0
+    for i, ind in enumerate(inds):
+        dur_from_audio = audio_sizes[i] / audio_rate
+        dur_from_label = lengths[i] / label_rate
+        if abs(dur_from_audio - dur_from_label) > tol:
+            logger.warning(
+                (
+                    f"audio and label duration differ too much "
+                    f"(|{dur_from_audio} - {dur_from_label}| > {tol}) "
+                    f"in line {ind+1} of {label_path}. Check if `label_rate` "
+                    f"is correctly set (currently {label_rate}). "
+                    f"num.
of samples = {audio_sizes[i]}; " + f"label length = {lengths[i]}" + ) + ) + num_invalid += 1 + if num_invalid > 0: + logger.warning( + f"total {num_invalid} (audio, label) pairs with mismatched lengths" + ) + + +class AVHubertDataset(FairseqDataset): + def __init__( + self, + manifest_path: str, + sample_rate: float, + label_paths: List[str], + label_rates: Union[List[float], float], # -1 for sequence labels + pad_list: List[str], + eos_list: List[str], + label_processors: Optional[List[Any]] = None, + max_keep_sample_size: Optional[int] = None, + min_keep_sample_size: Optional[int] = None, + max_sample_size: Optional[int] = None, + shuffle: bool = True, + pad_audio: bool = False, + normalize: bool = False, + store_labels: bool = True, + random_crop: bool = False, + single_target: bool = False, + stack_order_audio: int=1, + skip_verify: bool=False, + image_mean: float=0, + image_std: float=1, + image_crop_size: int=88, + image_aug: bool=False, + modalities: Optional[List[str]]=None, + is_s2s=False, + noise_fn=None, + noise_prob=0, + noise_snr=0, + noise_num=1 + ): + self.label_rates = ( + [label_rates for _ in range(len(label_paths))] + if isinstance(label_rates, int) + else label_rates + ) + self.modalities = set(modalities) + self.audio_root, self.names, inds, tot, self.sizes = load_audio_visual(manifest_path, max_keep_sample_size, min_keep_sample_size, frame_rate=sample_rate, label_paths=label_paths, label_rates=self.label_rates) + self.sample_rate = sample_rate + self.stack_order_audio = stack_order_audio + self.shuffle = shuffle + self.random_crop = random_crop + + self.num_labels = len(label_paths) + self.pad_list = pad_list + self.eos_list = eos_list + self.label_processors = label_processors + self.single_target = single_target + self.store_labels = store_labels + self.is_s2s = is_s2s + self.noise_wav, self.noise_prob, self.noise_snr, self.noise_num = [ln.strip() for ln in open(noise_fn).readlines()] if noise_fn is not None else [], noise_prob, noise_snr, noise_num + + assert self.single_target == (self.label_rates[0] == -1), f"single target should be equivalent to sequence label (label_rate==-1)" + if store_labels: + self.label_list = [load_label(p, inds, tot) for p in label_paths] + else: + self.label_paths = label_paths + self.label_offsets_list = [ + load_label_offset(p, inds, tot) for p in label_paths + ] + assert ( + label_processors is None + or len(label_processors) == self.num_labels + ) + if not skip_verify: + for label_path, label_rate in zip(label_paths, self.label_rates): + verify_label_lengths(self.sizes, self.sample_rate, label_path, label_rate, inds, tot) + else: + logger.info(f"Skip label alignment verifying") + + self.max_sample_size = ( + max_sample_size if max_sample_size is not None else sys.maxsize + ) + self.pad_audio = pad_audio + self.normalize = normalize + if image_aug: + self.transform = custom_utils.Compose([ + custom_utils.Normalize( 0.0,255.0 ), + custom_utils.RandomCrop((image_crop_size, image_crop_size)), + custom_utils.HorizontalFlip(0.5), + custom_utils.Normalize(image_mean, image_std) ]) + else: + self.transform = custom_utils.Compose([ + custom_utils.Normalize( 0.0,255.0 ), + custom_utils.CenterCrop((image_crop_size, image_crop_size)), + custom_utils.Normalize(image_mean, image_std) ]) + logger.info(f"image transform: {self.transform}") + + logger.info( + f"pad_audio={pad_audio}, random_crop={random_crop}, " + f"normalize={normalize}, max_sample_size={self.max_sample_size}, " + f"seqs2seq data={self.is_s2s},") + logger.info( + f"Noise wav: 
{noise_fn}->{len(self.noise_wav)} wav, Prob: {self.noise_prob}, SNR: {self.noise_snr}, Number of mixture: {self.noise_num}" + ) + + def get_label(self, index, label_idx): + if self.store_labels: + label = self.label_list[label_idx][index] + else: + with open(self.label_paths[label_idx]) as f: + offset_s, offset_e = self.label_offsets_list[label_idx][index] + f.seek(offset_s) + label = f.read(offset_e - offset_s) + + if self.label_processors is not None: + label = self.label_processors[label_idx](label) + return label + + def get_labels(self, index): + return [self.get_label(index, i) for i in range(self.num_labels)] + + def load_feature(self, mix_name): + """ + Load image and audio feature + Returns: + video_feats: numpy.ndarray of shape [T, H, W, 1], audio_feats: numpy.ndarray of shape [T, F] + """ + def stacker(feats, stack_order): + """ + Concatenating consecutive audio frames + Args: + feats - numpy.ndarray of shape [T, F] + stack_order - int (number of neighboring frames to concatenate + Returns: + feats - numpy.ndarray of shape [T', F'] + """ + feat_dim = feats.shape[1] + if len(feats) % stack_order != 0: + res = stack_order - len(feats) % stack_order + res = np.zeros([res, feat_dim]).astype(feats.dtype) + feats = np.concatenate([feats, res], axis=0) + feats = feats.reshape((-1, stack_order, feat_dim)).reshape(-1, stack_order*feat_dim) + return feats + video_fn, audio_fn = mix_name + if 'video' in self.modalities: + video_feats = self.load_video(video_fn) # [T, H, W, 1] + else: + video_feats = None + if 'audio' in self.modalities: + audio_fn = audio_fn.split(':')[0] + sample_rate, wav_data = wavfile.read(audio_fn) + assert sample_rate == 16_000 and len(wav_data.shape) == 1 + if np.random.rand() < self.noise_prob: + wav_data = self.add_noise(wav_data) + audio_feats = logfbank(wav_data, samplerate=sample_rate).astype(np.float32) # [T, F] + audio_feats = stacker(audio_feats, self.stack_order_audio) # [T/stack_order_audio, F*stack_order_audio] + else: + audio_feats = None + if audio_feats is not None and video_feats is not None: + diff = len(audio_feats) - len(video_feats) + if diff < 0: + audio_feats = np.concatenate([audio_feats, np.zeros([-diff, audio_feats.shape[-1]], dtype=audio_feats.dtype)]) + elif diff > 0: + audio_feats = audio_feats[:-diff] + return video_feats, audio_feats + + def load_video(self, audio_name): + feats = custom_utils.load_video(os.path.join(self.audio_root, audio_name)) + feats = self.transform(feats) + feats = np.expand_dims(feats, axis=-1) + return feats + + def select_noise(self): + rand_indexes = np.random.randint(0, len(self.noise_wav), size=self.noise_num) + noise_wav = [] + for x in rand_indexes: + noise_wav.append(wavfile.read(self.noise_wav[x])[1].astype(np.float32)) + if self.noise_num == 1: + return noise_wav[0] + else: + min_len = min([len(x) for x in noise_wav]) + noise_wav = [x[:min_len] for x in noise_wav] + noise_wav = np.floor(np.stack(noise_wav).mean(axis=0)) + return noise_wav + + def add_noise(self, clean_wav): + clean_wav = clean_wav.astype(np.float32) + noise_wav = self.select_noise() + if type(self.noise_snr) == int or type(self.noise_snr) == float: + snr = self.noise_snr + elif type(self.noise_snr) == tuple: + snr = np.random.randint(self.noise_snr[0], self.noise_snr[1]+1) + clean_rms = np.sqrt(np.mean(np.square(clean_wav), axis=-1)) + if len(clean_wav) > len(noise_wav): + ratio = int(np.ceil(len(clean_wav)/len(noise_wav))) + noise_wav = np.concatenate([noise_wav for _ in range(ratio)]) + if len(clean_wav) < len(noise_wav): + start = 0 + 
noise_wav = noise_wav[start: start + len(clean_wav)] + noise_rms = np.sqrt(np.mean(np.square(noise_wav), axis=-1)) + adjusted_noise_rms = clean_rms / (10**(snr/20)) + adjusted_noise_wav = noise_wav * (adjusted_noise_rms / noise_rms) + mixed = clean_wav + adjusted_noise_wav + + #Avoid clipping noise + max_int16 = np.iinfo(np.int16).max + min_int16 = np.iinfo(np.int16).min + if mixed.max(axis=0) > max_int16 or mixed.min(axis=0) < min_int16: + if mixed.max(axis=0) >= abs(mixed.min(axis=0)): + reduction_rate = max_int16 / mixed.max(axis=0) + else : + reduction_rate = min_int16 / mixed.min(axis=0) + mixed = mixed * (reduction_rate) + mixed = mixed.astype(np.int16) + return mixed + + def __getitem__(self, index): + video_feats, audio_feats = self.load_feature(self.names[index]) + audio_feats, video_feats = torch.from_numpy(audio_feats.astype(np.float32)) if audio_feats is not None else None, torch.from_numpy(video_feats.astype(np.float32)) if video_feats is not None else None + if self.normalize and 'audio' in self.modalities: + with torch.no_grad(): + audio_feats = F.layer_norm(audio_feats, audio_feats.shape[1:]) + labels = self.get_labels(index) + fid = self.names[index][1].split(':')[1] + return {"id": index, 'fid': fid, "video_source": video_feats, 'audio_source': audio_feats, "label_list": labels} + + def __len__(self): + return len(self.sizes) + + def crop_to_max_size(self, wav, target_size, start=None): + size = len(wav) + diff = size - target_size + if diff <= 0: + return wav, 0 + # longer utterances + if start is None: + start, end = 0, target_size + if self.random_crop: + start = np.random.randint(0, diff + 1) + end = size - diff + start + else: + end = start + target_size + return wav[start:end], start + + def collater(self, samples): + samples = [s for s in samples if s["id"] is not None] + if len(samples) == 0: + return {} + + audio_source, video_source = [s["audio_source"] for s in samples], [s["video_source"] for s in samples] + if audio_source[0] is None: + audio_source = None + if video_source[0] is None: + video_source = None + if audio_source is not None: + audio_sizes = [len(s) for s in audio_source] + else: + audio_sizes = [len(s) for s in video_source] + if self.pad_audio: + audio_size = min(max(audio_sizes), self.max_sample_size) + else: + audio_size = min(min(audio_sizes), self.max_sample_size) + if audio_source is not None: + collated_audios, padding_mask, audio_starts = self.collater_audio(audio_source, audio_size) + else: + collated_audios, audio_starts = None, None + if video_source is not None: + collated_videos, padding_mask, audio_starts = self.collater_audio(video_source, audio_size, audio_starts) + else: + collated_videos = None + targets_by_label = [ + [s["label_list"][i] for s in samples] + for i in range(self.num_labels) + ] + targets_list, lengths_list, ntokens_list = self.collater_label( + targets_by_label, audio_size, audio_starts + ) + source = {"audio": collated_audios, "video": collated_videos} + net_input = {"source": source, "padding_mask": padding_mask} + batch = { + "id": torch.LongTensor([s["id"] for s in samples]), + "net_input": net_input, + "utt_id": [s['fid'] for s in samples] + } + + if self.single_target: + batch["target_lengths"] = lengths_list[0] + batch["ntokens"] = ntokens_list[0] + if self.is_s2s: + batch['target'], net_input['prev_output_tokens'] = targets_list[0][0], targets_list[0][1] + else: + batch["target"] = targets_list[0] + else: + batch["target_lengths_list"] = lengths_list + batch["ntokens_list"] = ntokens_list + 
batch["target_list"] = targets_list + return batch + + def collater_audio(self, audios, audio_size, audio_starts=None): + audio_feat_shape = list(audios[0].shape[1:]) + collated_audios = audios[0].new_zeros([len(audios), audio_size]+audio_feat_shape) + padding_mask = ( + torch.BoolTensor(len(audios), audio_size).fill_(False) # + ) + start_known = audio_starts is not None + audio_starts = [0 for _ in audios] if not start_known else audio_starts + for i, audio in enumerate(audios): + diff = len(audio) - audio_size + if diff == 0: + collated_audios[i] = audio + elif diff < 0: + assert self.pad_audio + collated_audios[i] = torch.cat( + [audio, audio.new_full([-diff]+audio_feat_shape, 0.0)] + ) + padding_mask[i, diff:] = True + else: + collated_audios[i], audio_starts[i] = self.crop_to_max_size( + audio, audio_size, audio_starts[i] if start_known else None + ) + if len(audios[0].shape) == 2: + collated_audios = collated_audios.transpose(1, 2) # [B, T, F] -> [B, F, T] + else: + collated_audios = collated_audios.permute((0, 4, 1, 2, 3)).contiguous() # [B, T, H, W, C] -> [B, C, T, H, W] + return collated_audios, padding_mask, audio_starts + + def collater_frm_label( + self, targets, audio_size, audio_starts, label_rate, pad + ): + assert label_rate > 0 + s2f = label_rate / self.sample_rate # num label per sample + frm_starts = [int(round(s * s2f)) for s in audio_starts] + frm_size = int(round(audio_size * s2f)) + if not self.pad_audio: + rem_size = [len(t) - s for t, s in zip(targets, frm_starts)] + frm_size = min(frm_size, *rem_size) + targets = [t[s: s + frm_size] for t, s in zip(targets, frm_starts)] + logger.debug(f"audio_starts={audio_starts}") + logger.debug(f"frame_starts={frm_starts}") + logger.debug(f"frame_size={frm_size}") + + lengths = torch.LongTensor([len(t) for t in targets]) + ntokens = lengths.sum().item() + targets = data_utils.collate_tokens( + targets, pad_idx=pad, left_pad=False + ) + return targets, lengths, ntokens + + def collater_seq_label(self, targets, pad): + lengths = torch.LongTensor([len(t) for t in targets]) + ntokens = lengths.sum().item() + targets = data_utils.collate_tokens( + targets, pad_idx=pad, left_pad=False + ) + return targets, lengths, ntokens + + def collater_seq_label_s2s(self, targets, pad): + lengths = torch.LongTensor([len(t) for t in targets]) + ntokens = lengths.sum().item() + pad, eos = self.label_processors[0].dictionary.pad(), self.label_processors[0].dictionary.eos() + targets_ = data_utils.collate_tokens(targets, pad_idx=pad, eos_idx=eos, left_pad=False) + prev_output_tokens = data_utils.collate_tokens(targets, pad_idx=pad, eos_idx=eos, left_pad=False, move_eos_to_beginning=True) + return (targets_, prev_output_tokens), lengths, ntokens + + def collater_label(self, targets_by_label, audio_size, audio_starts): + targets_list, lengths_list, ntokens_list = [], [], [] + itr = zip(targets_by_label, self.label_rates, self.pad_list) + for targets, label_rate, pad in itr: + if label_rate == -1: + if self.is_s2s: + targets, lengths, ntokens = self.collater_seq_label_s2s(targets, pad) + else: + targets, lengths, ntokens = self.collater_seq_label(targets, pad) + else: + targets, lengths, ntokens = self.collater_frm_label( + targets, audio_size, audio_starts, label_rate, pad + ) + targets_list.append(targets) + lengths_list.append(lengths) + ntokens_list.append(ntokens) + return targets_list, lengths_list, ntokens_list + + def num_tokens(self, index): + return self.size(index) + + def size(self, index): + if self.pad_audio: + return self.sizes[index] + 
return min(self.sizes[index], self.max_sample_size) + + def ordered_indices(self): + if self.shuffle: + order = [np.random.permutation(len(self))] + else: + order = [np.arange(len(self))] + + order.append(self.sizes) + return np.lexsort(order)[::-1] diff --git a/src/slam_llm/models/avhubert/hubert_pretraining.py b/src/slam_llm/models/avhubert/hubert_pretraining.py new file mode 100644 index 00000000..25fdd1ba --- /dev/null +++ b/src/slam_llm/models/avhubert/hubert_pretraining.py @@ -0,0 +1,401 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os, glob +import sys +from typing import Dict, List, Optional, Tuple + +import numpy as np + +from dataclasses import dataclass, field +from fairseq import metrics, search +from fairseq.data import Dictionary, encoders +from fairseq.dataclass.configs import FairseqDataclass +from fairseq.tasks import register_task +from fairseq.tasks.fairseq_task import FairseqTask +from omegaconf import MISSING, II +import numpy as np +from argparse import Namespace + +DBG=True if len(sys.argv) == 1 else False + +if DBG: + from hubert_dataset import AVHubertDataset + from sequence_generator import SequenceGenerator +else: + from .hubert_dataset import AVHubertDataset + from .sequence_generator import SequenceGenerator + +logger = logging.getLogger(__name__) + + +class LabelEncoder(object): + def __init__(self, dictionary: Dictionary) -> None: + self.dictionary = dictionary + + def __call__(self, label: str) -> List[str]: + return self.dictionary.encode_line( + label, append_eos=False, add_if_not_exist=False, + ) + +class LabelEncoderS2SToken(object): + def __init__(self, dictionary: Dictionary, bpe_tokenizer) -> None: + self.bpe_tokenizer = bpe_tokenizer + self.dictionary = dictionary + + def __call__(self, label: str) -> List[str]: + label = self.bpe_tokenizer.encode(label.lower()) + return self.dictionary.encode_line( + label, append_eos=True, add_if_not_exist=False, + ).long() + + def decode(self, tok, symbols_ignore=None): + tok = self.dictionary.string(tok, extra_symbols_to_ignore=symbols_ignore) + if self.bpe_tokenizer: + tok = self.bpe_tokenizer.decode(tok) + return tok + +@dataclass +class AVHubertPretrainingConfig(FairseqDataclass): + input_modality: str = II("task.input_modality") #?? + data: str = field( + default=MISSING, metadata={"help": "path to data directory"} + ) + labels: List[str] = field( + default_factory=lambda: ["ltr"], + metadata={ + "help": ( + "extension of the label files to load, frame-level labels for" + " pre-training, and sequence-level label for fine-tuning" + ) + }, + ) + label_dir: Optional[str] = field( + default=None, + metadata={ + "help": "if set, looks for labels in this directory instead", + }, + ) + label_rate: int = field( + default=-1, + metadata={"help": "label frame rate. -1 for sequence label"}, + ) + + sample_rate: int = field( + default=16_000, + metadata={ + "help": "target sample rate. 
audio files will be up/down " + "sampled to this rate" + }, + ) + normalize: bool = field( + default=False, + metadata={ + "help": "if set, normalizes input to have 0 mean and unit variance" + }, + ) + enable_padding: bool = field( + default=False, + metadata={"help": "pad shorter samples instead of cropping"}, + ) + max_sample_size: Optional[int] = field( + default=None, + metadata={"help": "max sample size to keep in training"}, + ) + min_sample_size: Optional[int] = field( + default=None, + metadata={"help": "min sample size to keep in training"}, + ) + max_trim_sample_size: Optional[int] = field( + default=II("task.max_sample_size"), + metadata={"help": "max sample size to trim to for batching"}, + ) + single_target: Optional[bool] = field( + default=False, + metadata={ + "help": "if set, AddTargetDatasets outputs same keys " + "as AddTargetDataset" + }, + ) + random_crop: Optional[bool] = field( + default=True, + metadata={"help": "always crop from the beginning if false"}, + ) + pad_audio: Optional[bool] = field( + default=False, + metadata={"help": "pad audio to the longest one in the batch if true"}, + ) + pdb: Optional[bool] = field( + default=False, + metadata={"help": "pdb"}, + ) + stack_order_audio: int = field( + default=1, + metadata={"help": "concatenate n consecutive audio frames for one step"}, + ) + skip_verify: Optional[bool] = field( + default=False, + metadata={"help": "skip verifying label-audio alignment"}, + ) + image_aug: bool = field(default=False, metadata={'help': 'image data augmentation'}) + image_crop_size: int = field( + default=88, metadata={"help": "image ROI size"}) + image_mean: float = field( + default=0.421, metadata={"help": "image mean"}) + image_std: float = field( + default=0.165, metadata={"help": "image std"}) + modalities: Optional[List[str]] = field(default_factory=lambda: ["audio", "video"], metadata={'help': 'modalities to load'}) + is_s2s: bool=field(default=False, metadata={'help': 'seq2seq fine-tuning only'}) + tokenizer_bpe_name: Optional[str] = field(default=None, metadata={'help': 'tokenizer model name'}) + tokenizer_bpe_model: Optional[str] = field(default=None, metadata={'help': 'tokenizer model path'}) + noise_wav: Optional[str] = field(default=None, metadata={'help': 'manifest of noise wav files (one wav file path per line)'}) + noise_prob: float = field(default=0, metadata={'help': 'noise probability'}) + noise_snr: Optional[str] = field(default='0', metadata={'help': 'noise SNR in audio'}) + noise_num: int = field(default=1, metadata={'help': 'number of noise wav files to mix'}) + fine_tuning: bool = field(default=False, metadata={"help": "set to true if fine-tuning AV-Hubert"}) + +@register_task("av_hubert_pretraining", dataclass=AVHubertPretrainingConfig) +class AVHubertPretrainingTask(FairseqTask): + + cfg: AVHubertPretrainingConfig + + def __init__( + self, + cfg: AVHubertPretrainingConfig, + ) -> None: + super().__init__(cfg) + + logger.info(f"current directory is {os.getcwd()}") + logger.info(f"AVHubertPretrainingTask Config {cfg}") + + self.fine_tuning = cfg.fine_tuning + if cfg.fine_tuning: + self.state.add_factory("target_dictionary", self.load_dictionaries) + if cfg.is_s2s: + self.state.add_factory("s2s_tokenizer", self.load_tokenizer) + else: + self.state.add_factory("dictionaries", self.load_dictionaries) + + self.blank_symbol = "" + + @property + def source_dictionary(self) -> Optional[Dictionary]: + return None # self._source_dictionary + + @property + def target_dictionary(self) -> Optional[Dictionary]: + return 
self.state.target_dictionary # self._target_dictionary + + @property + def dictionaries(self) -> List[Dictionary]: + return self.state.dictionaries + + def load_dictionaries(self): + label_dir = self.cfg.data if self.cfg.label_dir is None else self.cfg.label_dir + dictionaries = [ + Dictionary.load(f"{label_dir}/dict.{label}.txt") + for label in self.cfg.labels + ] + return dictionaries[0] if self.cfg.fine_tuning else dictionaries + + def load_tokenizer(self): + bpe_args = Namespace(**{'bpe': self.cfg.tokenizer_bpe_name, f"{self.cfg.tokenizer_bpe_name}_model": self.cfg.tokenizer_bpe_model}) + bpe_tokenizer = encoders.build_bpe(bpe_args) + return bpe_tokenizer + + @property + def s2s_tokenizer(self): + return self.state.s2s_tokenizer + + @classmethod + def setup_task( + cls, cfg: AVHubertPretrainingConfig, **kwargs + ) -> "AVHubertPretrainingTask": + if cfg.pdb: + import pdb + pdb.set_trace() + return cls(cfg) + + def get_label_dir(self) -> str: + if self.cfg.label_dir is None: + return self.cfg.data + return self.cfg.label_dir + + def load_dataset(self, split: str, **kwargs) -> None: + manifest = f"{self.cfg.data}/{split}.tsv" + dictionaries = [self.target_dictionary] if self.fine_tuning else self.dictionaries + pad_list = [dictionary.pad() for dictionary in dictionaries] + eos_list = [dictionary.eos() for dictionary in dictionaries] + if not self.cfg.is_s2s: + procs = [LabelEncoder(dictionary) for dictionary in dictionaries] + else: + logger.info(f"Using tokenizer") + bpe_tokenizer = self.s2s_tokenizer + procs = [LabelEncoderS2SToken(dictionary, bpe_tokenizer) for dictionary in dictionaries] + paths = [ + f"{self.get_label_dir()}/{split}.{l}" for l in self.cfg.labels + ] + image_aug = self.cfg.image_aug if split == 'train' else False + noise_fn, noise_snr = f"{self.cfg.noise_wav}/{split}.tsv" if self.cfg.noise_wav is not None else None, eval(self.cfg.noise_snr) + noise_num = self.cfg.noise_num # + self.datasets[split] = AVHubertDataset( + manifest, + sample_rate=self.cfg.sample_rate, + label_paths=paths, + label_rates=self.cfg.label_rate, + pad_list=pad_list, + eos_list=eos_list, + label_processors=procs, + max_keep_sample_size=self.cfg.max_sample_size, + min_keep_sample_size=self.cfg.min_sample_size, + max_sample_size=self.cfg.max_trim_sample_size, + pad_audio=self.cfg.pad_audio, + normalize=self.cfg.normalize, + store_labels=False, + random_crop=self.cfg.random_crop, + single_target=self.cfg.single_target, + stack_order_audio=self.cfg.stack_order_audio, + skip_verify=self.cfg.skip_verify, + image_mean=self.cfg.image_mean, + image_std=self.cfg.image_std, + image_crop_size=self.cfg.image_crop_size, + image_aug=image_aug, + modalities=self.cfg.modalities, + is_s2s=self.cfg.is_s2s, + noise_fn=noise_fn, + noise_prob=self.cfg.noise_prob, + noise_snr=noise_snr, + noise_num=noise_num + ) + + def max_positions(self) -> Tuple[int, int]: + return (sys.maxsize, sys.maxsize) + + def filter_indices_by_size( + self, indices: np.array, *args, **kwargs + ) -> np.array: + return indices + + def build_generator( + self, models, args, seq_gen_cls=None, extra_gen_cls_kwargs=None, prefix_allowed_tokens_fn=None, + ): + """ + Build a :class:`~fairseq.SequenceGenerator` instance for this + task. 
+ Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models + args (fairseq.dataclass.configs.GenerationConfig): + configuration object (dataclass) for generation + extra_gen_cls_kwargs (Dict[str, Any]): extra options to pass + through to SequenceGenerator + prefix_allowed_tokens_fn (Callable[[int, torch.Tensor], List[int]]): + If provided, this function constrains the beam search to + allowed tokens only at each step. The provided function + should take 2 arguments: the batch ID (`batch_id: int`) + and a unidimensional tensor of token ids (`inputs_ids: + torch.Tensor`). It has to return a `List[int]` with the + allowed tokens for the next generation step conditioned + on the previously generated tokens (`inputs_ids`) and + the batch ID (`batch_id`). This argument is useful for + constrained generation conditioned on the prefix, as + described in "Autoregressive Entity Retrieval" + (https://arxiv.org/abs/2010.00904) and + https://github.com/facebookresearch/GENRE. + """ + if getattr(args, "score_reference", False): + from fairseq.sequence_scorer import SequenceScorer + + return SequenceScorer( + self.target_dictionary, + compute_alignment=getattr(args, "print_alignment", False), + ) + + # Choose search strategy. Defaults to Beam Search. + sampling = getattr(args, "sampling", False) + sampling_topk = getattr(args, "sampling_topk", -1) + sampling_topp = getattr(args, "sampling_topp", -1.0) + diverse_beam_groups = getattr(args, "diverse_beam_groups", -1) + diverse_beam_strength = getattr(args, "diverse_beam_strength", 0.5) + match_source_len = getattr(args, "match_source_len", False) + diversity_rate = getattr(args, "diversity_rate", -1) + constrained = getattr(args, "constraints", False) + if prefix_allowed_tokens_fn is None: + prefix_allowed_tokens_fn = getattr(args, "prefix_allowed_tokens_fn", None) + if ( + sum( + int(cond) + for cond in [ + sampling, + diverse_beam_groups > 0, + match_source_len, + diversity_rate > 0, + ] + ) + > 1 + ): + raise ValueError("Provided Search parameters are mutually exclusive.") + assert sampling_topk < 0 or sampling, "--sampling-topk requires --sampling" + assert sampling_topp < 0 or sampling, "--sampling-topp requires --sampling" + + if sampling: + search_strategy = search.Sampling( + self.target_dictionary, sampling_topk, sampling_topp + ) + elif diverse_beam_groups > 0: + search_strategy = search.DiverseBeamSearch( + self.target_dictionary, diverse_beam_groups, diverse_beam_strength + ) + elif match_source_len: + # this is useful for tagging applications where the output + # length should match the input length, so we hardcode the + # length constraints for simplicity + search_strategy = search.LengthConstrainedBeamSearch( + self.target_dictionary, + min_len_a=1, + min_len_b=0, + max_len_a=1, + max_len_b=0, + ) + elif diversity_rate > -1: + search_strategy = search.DiverseSiblingsSearch( + self.target_dictionary, diversity_rate + ) + elif constrained: + search_strategy = search.LexicallyConstrainedBeamSearch( + self.target_dictionary, args.constraints + ) + elif prefix_allowed_tokens_fn: + search_strategy = search.PrefixConstrainedBeamSearch( + self.target_dictionary, prefix_allowed_tokens_fn + ) + else: + search_strategy = search.BeamSearch(self.target_dictionary) + + extra_gen_cls_kwargs = extra_gen_cls_kwargs or {} + if seq_gen_cls is None: + if getattr(args, "print_alignment", False): + seq_gen_cls = SequenceGeneratorWithAlignment + extra_gen_cls_kwargs["print_alignment"] = args.print_alignment + else: + seq_gen_cls = SequenceGenerator 
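+        # Illustrative sketch: the docstring above describes `prefix_allowed_tokens_fn`.
+        # A hypothetical constraint that simply forbids the unknown token at every
+        # decoding step could look like the commented function below.
+        #
+        #     def no_unk(batch_id: int, input_ids: torch.Tensor) -> List[int]:
+        #         unk = self.target_dictionary.unk()
+        #         return [t for t in range(len(self.target_dictionary)) if t != unk]
+        #
+        # Passing such a callable selects search.PrefixConstrainedBeamSearch above.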
+ + return seq_gen_cls( + models, + self.target_dictionary, + beam_size=getattr(args, "beam", 5), + max_len_a=getattr(args, "max_len_a", 0), + max_len_b=getattr(args, "max_len_b", 200), + min_len=getattr(args, "min_len", 1), + normalize_scores=(not getattr(args, "unnormalized", False)), + len_penalty=getattr(args, "lenpen", 1), + unk_penalty=getattr(args, "unkpen", 0), + temperature=getattr(args, "temperature", 1.0), + match_source_len=getattr(args, "match_source_len", False), + no_repeat_ngram_size=getattr(args, "no_repeat_ngram_size", 0), + search_strategy=search_strategy, + **extra_gen_cls_kwargs, + ) diff --git a/src/slam_llm/models/avhubert/infer_s2s.py b/src/slam_llm/models/avhubert/infer_s2s.py new file mode 100644 index 00000000..f0751e9d --- /dev/null +++ b/src/slam_llm/models/avhubert/infer_s2s.py @@ -0,0 +1,318 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import ast +from itertools import chain +import logging +import math +import os +import sys +import json +import hashlib +import editdistance +from argparse import Namespace + +import numpy as np +import torch +from fairseq import checkpoint_utils, options, tasks, utils, distributed_utils +from fairseq.dataclass.utils import convert_namespace_to_omegaconf +from fairseq.logging import progress_bar +from fairseq.logging.meters import StopwatchMeter, TimeMeter +from fairseq.models import FairseqLanguageModel +from omegaconf import DictConfig + +from pathlib import Path +import hydra +from hydra.core.config_store import ConfigStore +from fairseq.dataclass.configs import ( + CheckpointConfig, + CommonConfig, + CommonEvalConfig, + DatasetConfig, + DistributedTrainingConfig, + GenerationConfig, + FairseqDataclass, +) +from dataclasses import dataclass, field, is_dataclass +from typing import Any, Dict, List, Optional, Tuple, Union +from omegaconf import OmegaConf + +logging.root.setLevel(logging.INFO) +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +config_path = Path(__file__).resolve().parent / "conf" + +@dataclass +class OverrideConfig(FairseqDataclass): + noise_wav: Optional[str] = field(default=None, metadata={'help': 'noise wav file'}) + noise_prob: float = field(default=0, metadata={'help': 'noise probability'}) + noise_snr: float = field(default=0, metadata={'help': 'noise SNR in audio'}) + modalities: List[str] = field(default_factory=lambda: [""], metadata={'help': 'which modality to use'}) + data: Optional[str] = field(default=None, metadata={'help': 'path to test data directory'}) + label_dir: Optional[str] = field(default=None, metadata={'help': 'path to test label directory'}) + +@dataclass +class InferConfig(FairseqDataclass): + task: Any = None + generation: GenerationConfig = GenerationConfig() + common: CommonConfig = CommonConfig() + common_eval: CommonEvalConfig = CommonEvalConfig() + checkpoint: CheckpointConfig = CheckpointConfig() + distributed_training: DistributedTrainingConfig = DistributedTrainingConfig() + dataset: DatasetConfig = DatasetConfig() + override: OverrideConfig = OverrideConfig() + is_ax: bool = field( + default=False, + metadata={ + "help": "if true, assumes we are using ax for tuning and returns a tuple for ax to consume" + }, + ) + + +def main(cfg: DictConfig): + + if isinstance(cfg, Namespace): + cfg = convert_namespace_to_omegaconf(cfg) + + assert cfg.common_eval.path is not None, "--path required 
for recognition!" + assert ( + not cfg.generation.sampling or cfg.generation.nbest == cfg.generation.beam + ), "--sampling requires --nbest to be equal to --beam" + + if cfg.common_eval.results_path is not None: + os.makedirs(cfg.common_eval.results_path, exist_ok=True) + output_path = os.path.join(cfg.common_eval.results_path, "decode.log") + with open(output_path, "w", buffering=1, encoding="utf-8") as h: + return _main(cfg, h) + return _main(cfg, sys.stdout) + + +def get_symbols_to_strip_from_output(generator): + if hasattr(generator, "symbols_to_strip_from_output"): + return generator.symbols_to_strip_from_output + else: + return {generator.eos, generator.pad} + +def _main(cfg, output_file): + logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=output_file, + ) + logger = logging.getLogger("hybrid.speech_recognize") + if output_file is not sys.stdout: # also print to stdout + logger.addHandler(logging.StreamHandler(sys.stdout)) + + utils.import_user_module(cfg.common) + models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task([cfg.common_eval.path]) + models = [model.eval().cuda() for model in models] #!! + saved_cfg.task.modalities = cfg.override.modalities + task = tasks.setup_task(saved_cfg.task) + + task.build_tokenizer(saved_cfg.tokenizer) + task.build_bpe(saved_cfg.bpe) + + logger.info(cfg) + + # Fix seed for stochastic decoding + if cfg.common.seed is not None and not cfg.generation.no_seed_provided: + np.random.seed(cfg.common.seed) + utils.set_torch_seed(cfg.common.seed) + + use_cuda = torch.cuda.is_available() + + # Set dictionary + dictionary = task.target_dictionary + + # loading the dataset should happen after the checkpoint has been loaded so we can give it the saved task config + task.cfg.noise_prob = cfg.override.noise_prob + task.cfg.noise_snr = cfg.override.noise_snr + task.cfg.noise_wav = cfg.override.noise_wav + if cfg.override.data is not None: + task.cfg.data = cfg.override.data + if cfg.override.label_dir is not None: + task.cfg.label_dir = cfg.override.label_dir + task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task) + + lms = [None] + + # Optimize ensemble for generation + for model in chain(models, lms): + if model is None: + continue + if cfg.common.fp16: + model.half() + if use_cuda and not cfg.distributed_training.pipeline_model_parallel: + model.cuda() + model.prepare_for_inference_(cfg) + + # Load dataset (possibly sharded) + itr = task.get_batch_iterator( + dataset=task.dataset(cfg.dataset.gen_subset), + max_tokens=cfg.dataset.max_tokens, + max_sentences=cfg.dataset.batch_size, + max_positions=utils.resolve_max_positions( + task.max_positions(), *[m.max_positions() for m in models] + ), + ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=cfg.dataset.required_batch_size_multiple, + seed=cfg.common.seed, + num_shards=cfg.distributed_training.distributed_world_size, + shard_id=cfg.distributed_training.distributed_rank, + num_workers=cfg.dataset.num_workers, + data_buffer_size=cfg.dataset.data_buffer_size, + ).next_epoch_itr(shuffle=False) + progress = progress_bar.progress_bar( + itr, + log_format=cfg.common.log_format, + log_interval=cfg.common.log_interval, + default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"), + ) + + # Initialize generator + if cfg.generation.match_source_len: + logger.warning( + "The option match_source_len is 
not applicable to speech recognition. Ignoring it." + ) + gen_timer = StopwatchMeter() + extra_gen_cls_kwargs = { + "lm_model": lms[0], + "lm_weight": cfg.generation.lm_weight, + } + cfg.generation.score_reference = False # + save_attention_plot = cfg.generation.print_alignment is not None + cfg.generation.print_alignment = None # + generator = task.build_generator( + models, cfg.generation, extra_gen_cls_kwargs=extra_gen_cls_kwargs + ) + + def decode_fn(x): + symbols_ignore = get_symbols_to_strip_from_output(generator) + symbols_ignore.add(dictionary.pad()) + if hasattr(task.datasets[cfg.dataset.gen_subset].label_processors[0], 'decode'): + return task.datasets[cfg.dataset.gen_subset].label_processors[0].decode(x, symbols_ignore) + chars = dictionary.string(x, extra_symbols_to_ignore=symbols_ignore) + words = " ".join("".join(chars.split()).replace('|', ' ').split()) + return words + + num_sentences = 0 + has_target = True + wps_meter = TimeMeter() + result_dict = {'utt_id': [], 'ref': [], 'hypo': []} + for sample in progress: + sample = utils.move_to_cuda(sample) if use_cuda else sample + if "net_input" not in sample: + continue + + prefix_tokens = None + if cfg.generation.prefix_size > 0: + prefix_tokens = sample["target"][:, : cfg.generation.prefix_size] + + constraints = None + if "constraints" in sample: + constraints = sample["constraints"] + + gen_timer.start() + hypos = task.inference_step( + generator, + models, + sample, + prefix_tokens=prefix_tokens, + constraints=constraints, + ) + num_generated_tokens = sum(len(h[0]["tokens"]) for h in hypos) + gen_timer.stop(num_generated_tokens) + + for i in range(len(sample["id"])): + result_dict['utt_id'].append(sample['utt_id'][i]) + ref_sent = decode_fn(sample['target'][i].int().cpu()) + result_dict['ref'].append(ref_sent) + best_hypo = hypos[i][0]['tokens'].int().cpu() + hypo_str = decode_fn(best_hypo) + result_dict['hypo'].append(hypo_str) + logger.info(f"\nREF:{ref_sent}\nHYP:{hypo_str}\n") + wps_meter.update(num_generated_tokens) + progress.log({"wps": round(wps_meter.avg)}) + num_sentences += sample["nsentences"] if "nsentences" in sample else sample["id"].numel() + + logger.info("NOTE: hypothesis and token scores are output in base 2") + logger.info("Recognized {:,} utterances ({} tokens) in {:.1f}s ({:.2f} sentences/s, {:.2f} tokens/s)".format( + num_sentences, gen_timer.n, gen_timer.sum, num_sentences / gen_timer.sum, 1. 
/ gen_timer.avg)) + + yaml_str = OmegaConf.to_yaml(cfg.generation) + fid = int(hashlib.md5(yaml_str.encode("utf-8")).hexdigest(), 16) + fid = fid % 1000000 + result_fn = f"{cfg.common_eval.results_path}/hypo-{fid}.json" + json.dump(result_dict, open(result_fn, 'w'), indent=4) + n_err, n_total = 0, 0 + assert len(result_dict['hypo']) == len(result_dict['ref']) + for hypo, ref in zip(result_dict['hypo'], result_dict['ref']): + hypo, ref = hypo.strip().split(), ref.strip().split() + n_err += editdistance.eval(hypo, ref) + n_total += len(ref) + wer = 100 * n_err / n_total + wer_fn = f"{cfg.common_eval.results_path}/wer.{fid}" + with open(wer_fn, "w") as fo: + fo.write(f"WER: {wer}\n") + fo.write(f"err / num_ref_words = {n_err} / {n_total}\n\n") + fo.write(f"{yaml_str}") + logger.info(f"WER: {wer}%") + return + + +@hydra.main(config_path=config_path, config_name="infer") +def hydra_main(cfg: InferConfig) -> Union[float, Tuple[float, Optional[float]]]: + container = OmegaConf.to_container(cfg, resolve=True, enum_to_str=True) + cfg = OmegaConf.create(container) + OmegaConf.set_struct(cfg, True) + + if cfg.common.reset_logging: + reset_logging() + + wer = float("inf") + + try: + if cfg.common.profile: + with torch.cuda.profiler.profile(): + with torch.autograd.profiler.emit_nvtx(): + distributed_utils.call_main(cfg, main) + else: + distributed_utils.call_main(cfg, main) + + except BaseException as e: # pylint: disable=broad-except + if not cfg.common.suppress_crashes: + raise + else: + logger.error("Crashed! %s", str(e)) + return + + +def cli_main() -> None: + try: + from hydra._internal.utils import ( + get_args, + ) # pylint: disable=import-outside-toplevel + + cfg_name = get_args().config_name or "infer" + except ImportError: + logger.warning("Failed to get config name from hydra args") + cfg_name = "infer" + + cs = ConfigStore.instance() + cs.store(name=cfg_name, node=InferConfig) + + for k in InferConfig.__dataclass_fields__: + if is_dataclass(InferConfig.__dataclass_fields__[k].type): + v = InferConfig.__dataclass_fields__[k].default + cs.store(name=k, node=v) + + hydra_main() # pylint: disable=no-value-for-parameter + + +if __name__ == "__main__": + cli_main() diff --git a/src/slam_llm/models/avhubert/resnet.py b/src/slam_llm/models/avhubert/resnet.py new file mode 100644 index 00000000..e584f2b2 --- /dev/null +++ b/src/slam_llm/models/avhubert/resnet.py @@ -0,0 +1,169 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
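+
+# torch and OrderedDict are required by ResEncoder below when it loads pretrained
+# front-end weights (assumed not to be imported elsewhere in this file).
+import torch
+from collections import OrderedDict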
+ +import logging +import math +import torch.nn as nn +import pdb + + +logger = logging.getLogger(__name__) + +def conv3x3(in_planes, out_planes, stride=1): + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +def downsample_basic_block( inplanes, outplanes, stride ): + return nn.Sequential( + nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(outplanes), + ) + +def downsample_basic_block_v2( inplanes, outplanes, stride ): + return nn.Sequential( + nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False), + nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, bias=False), + nn.BatchNorm2d(outplanes), + ) + + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, relu_type = 'relu' ): + super(BasicBlock, self).__init__() + + assert relu_type in ['relu','prelu'] + + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes) + + if relu_type == 'relu': + self.relu1 = nn.ReLU(inplace=True) + self.relu2 = nn.ReLU(inplace=True) + elif relu_type == 'prelu': + self.relu1 = nn.PReLU(num_parameters=planes) + self.relu2 = nn.PReLU(num_parameters=planes) + else: + raise Exception('relu type not implemented') + + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes) + + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + out = self.conv1(x) + out = self.bn1(out) + out = self.relu1(out) + out = self.conv2(out) + out = self.bn2(out) + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu2(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000, relu_type = 'relu', gamma_zero = False, avg_pool_downsample = False): + self.inplanes = 64 + self.relu_type = relu_type + self.gamma_zero = gamma_zero + self.downsample_block = downsample_basic_block_v2 if avg_pool_downsample else downsample_basic_block + + super(ResNet, self).__init__() + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AdaptiveAvgPool2d(1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(0, math.sqrt(2. 
/ n)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + if self.gamma_zero: + for m in self.modules(): + if isinstance(m, BasicBlock ): + m.bn2.weight.data.zero_() + + def _make_layer(self, block, planes, blocks, stride=1): + + + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = self.downsample_block( inplanes = self.inplanes, + outplanes = planes * block.expansion, + stride = stride ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, relu_type = self.relu_type)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, relu_type = self.relu_type)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.avgpool(x) + x = x.view(x.size(0), -1) + return x + +class ResEncoder(nn.Module): + def __init__(self, relu_type, weights): + super(ResEncoder, self).__init__() + self.frontend_nout = 64 + self.backend_out = 512 + frontend_relu = nn.PReLU(num_parameters=self.frontend_nout) if relu_type == 'prelu' else nn.ReLU() + self.frontend3D = nn.Sequential( + nn.Conv3d(1, self.frontend_nout, kernel_size=(5, 7, 7), stride=(1, 2, 2), padding=(2, 3, 3), bias=False), + nn.BatchNorm3d(self.frontend_nout), + frontend_relu, + nn.MaxPool3d( kernel_size=(1, 3, 3), stride=(1, 2, 2), padding=(0, 1, 1))) + self.trunk = ResNet(BasicBlock, [2, 2, 2, 2], relu_type=relu_type) + if weights is not None: + logger.info(f"Load {weights} for resnet") + std = torch.load(weights, map_location=torch.device('cpu'))['model_state_dict'] + frontend_std, trunk_std = OrderedDict(), OrderedDict() + for key, val in std.items(): + new_key = '.'.join(key.split('.')[1:]) + if 'frontend3D' in key: + frontend_std[new_key] = val + if 'trunk' in key: + trunk_std[new_key] = val + self.frontend3D.load_state_dict(frontend_std) + self.trunk.load_state_dict(trunk_std) + + def forward(self, x): + B, C, T, H, W = x.size() + x = self.frontend3D(x) + Tnew = x.shape[2] + x = self.threeD_to_2D_tensor(x) + x = self.trunk(x) + x = x.view(B, Tnew, x.size(1)) + x = x.transpose(1, 2).contiguous() + return x + + def threeD_to_2D_tensor(self, x): + n_batch, n_channels, s_time, sx, sy = x.shape + x = x.transpose(1, 2).contiguous() + return x.reshape(n_batch*s_time, n_channels, sx, sy) diff --git a/src/slam_llm/models/avhubert/sequence_generator.py b/src/slam_llm/models/avhubert/sequence_generator.py new file mode 100644 index 00000000..eb7ac356 --- /dev/null +++ b/src/slam_llm/models/avhubert/sequence_generator.py @@ -0,0 +1,985 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
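+
+# Typical driver for this generator, following infer_s2s.py in this patch
+# (a sketch with illustrative values, not a prescribed API):
+#
+#     generator = task.build_generator(
+#         models, cfg.generation,
+#         extra_gen_cls_kwargs={"lm_model": None, "lm_weight": 0.0},
+#     )
+#     hypos = task.inference_step(generator, models, sample)
+#     best_tokens = hypos[0][0]["tokens"]  # top hypothesis for the first utterance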
+ +import math +from typing import Dict, List, Optional +import sys + +import torch +import torch.nn as nn +from fairseq import search, utils +from fairseq.data import data_utils +from fairseq.models import FairseqIncrementalDecoder +from torch import Tensor +from fairseq.ngram_repeat_block import NGramRepeatBlock + + +class SequenceGenerator(nn.Module): + def __init__( + self, + models, + tgt_dict, + beam_size=1, + max_len_a=0, + max_len_b=200, + max_len=0, + min_len=1, + normalize_scores=True, + len_penalty=1.0, + unk_penalty=0.0, + temperature=1.0, + match_source_len=False, + no_repeat_ngram_size=0, + search_strategy=None, + eos=None, + symbols_to_strip_from_output=None, + lm_model=None, + lm_weight=1.0, + ): + """Generates translations of a given source sentence. + + Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models, + currently support fairseq.models.TransformerModel for scripting + beam_size (int, optional): beam width (default: 1) + max_len_a/b (int, optional): generate sequences of maximum length + ax + b, where x is the source length + max_len (int, optional): the maximum length of the generated output + (not including end-of-sentence) + min_len (int, optional): the minimum length of the generated output + (not including end-of-sentence) + normalize_scores (bool, optional): normalize scores by the length + of the output (default: True) + len_penalty (float, optional): length penalty, where <1.0 favors + shorter, >1.0 favors longer sentences (default: 1.0) + unk_penalty (float, optional): unknown word penalty, where <0 + produces more unks, >0 produces fewer (default: 0.0) + temperature (float, optional): temperature, where values + >1.0 produce more uniform samples and values <1.0 produce + sharper samples (default: 1.0) + match_source_len (bool, optional): outputs should match the source + length (default: False) + """ + super().__init__() + if isinstance(models, EnsembleModel): + self.model = models + else: + self.model = EnsembleModel(models) + self.tgt_dict = tgt_dict + self.pad = tgt_dict.pad() + self.unk = tgt_dict.unk() + self.eos = tgt_dict.eos() if eos is None else eos + self.symbols_to_strip_from_output = ( + symbols_to_strip_from_output.union({self.eos}) + if symbols_to_strip_from_output is not None + else {self.eos} + ) + self.vocab_size = len(tgt_dict) + self.beam_size = beam_size + # the max beam size is the dictionary size - 1, since we never select pad + self.beam_size = min(beam_size, self.vocab_size - 1) + self.max_len_a = max_len_a + self.max_len_b = max_len_b + self.min_len = min_len + self.max_len = max_len or self.model.max_decoder_positions() + + self.normalize_scores = normalize_scores + self.len_penalty = len_penalty + self.unk_penalty = unk_penalty + self.temperature = temperature + self.match_source_len = match_source_len + + if no_repeat_ngram_size > 0: + self.repeat_ngram_blocker = NGramRepeatBlock(no_repeat_ngram_size) + else: + self.repeat_ngram_blocker = None + + assert temperature > 0, "--temperature must be greater than 0" + + self.search = ( + search.BeamSearch(tgt_dict) if search_strategy is None else search_strategy + ) + # We only need to set src_lengths in LengthConstrainedBeamSearch. + # As a module attribute, setting it would break in multithread + # settings when the model is shared. 
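+        # Worked example of the length settings documented above (illustrative
+        # numbers only): with max_len_a=0, max_len_b=200 and a 500-frame source,
+        # _generate caps hypotheses at min(0 * 500 + 200, self.max_len - 1) tokens;
+        # with len_penalty=1.0 a finished score is divided by (step + 1) ** 1.0,
+        # so output length neither helps nor hurts the ranking.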
+ self.should_set_src_lengths = ( + hasattr(self.search, "needs_src_lengths") and self.search.needs_src_lengths + ) + + self.model.eval() + + self.lm_model = lm_model + self.lm_weight = lm_weight + if self.lm_model is not None: + self.lm_model.eval() + + def cuda(self): + self.model.cuda() + return self + + @torch.no_grad() + def forward( + self, + sample: Dict[str, Dict[str, Tensor]], + prefix_tokens: Optional[Tensor] = None, + bos_token: Optional[int] = None, + ): + """Generate a batch of translations. + + Args: + sample (dict): batch + prefix_tokens (torch.LongTensor, optional): force decoder to begin + with these tokens + bos_token (int, optional): beginning of sentence token + (default: self.eos) + """ + return self._generate(sample, prefix_tokens, bos_token=bos_token) + + # TODO(myleott): unused, deprecate after pytorch-translate migration + def generate_batched_itr(self, data_itr, beam_size=None, cuda=False, timer=None): + """Iterate over a batched dataset and yield individual translations. + Args: + cuda (bool, optional): use GPU for generation + timer (StopwatchMeter, optional): time generations + """ + for sample in data_itr: + s = utils.move_to_cuda(sample) if cuda else sample + if "net_input" not in s: + continue + input = s["net_input"] + # model.forward normally channels prev_output_tokens into the decoder + # separately, but SequenceGenerator directly calls model.encoder + encoder_input = { + k: v for k, v in input.items() if k != "prev_output_tokens" + } + if timer is not None: + timer.start() + with torch.no_grad(): + hypos = self.generate(encoder_input) + if timer is not None: + timer.stop(sum(len(h[0]["tokens"]) for h in hypos)) + for i, id in enumerate(s["id"].data): + # remove padding + src = utils.strip_pad(input["src_tokens"].data[i, :], self.pad) + ref = ( + utils.strip_pad(s["target"].data[i, :], self.pad) + if s["target"] is not None + else None + ) + yield id, src, ref, hypos[i] + + @torch.no_grad() + def generate(self, models, sample: Dict[str, Dict[str, Tensor]], **kwargs) -> List[List[Dict[str, Tensor]]]: + """Generate translations. Match the api of other fairseq generators. 
+ + Args: + models (List[~fairseq.models.FairseqModel]): ensemble of models + sample (dict): batch + prefix_tokens (torch.LongTensor, optional): force decoder to begin + with these tokens + constraints (torch.LongTensor, optional): force decoder to include + the list of constraints + bos_token (int, optional): beginning of sentence token + (default: self.eos) + """ + return self._generate(sample, **kwargs) + + def _generate( + self, + sample: Dict[str, Dict[str, Tensor]], + prefix_tokens: Optional[Tensor] = None, + constraints: Optional[Tensor] = None, + bos_token: Optional[int] = None, + ): + incremental_states = torch.jit.annotate( + List[Dict[str, Dict[str, Optional[Tensor]]]], + [ + torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {}) + for i in range(self.model.models_size) + ], + ) + net_input = sample["net_input"] + + if "src_tokens" in net_input: + src_tokens = net_input["src_tokens"] + # length of the source text being the character length except EndOfSentence and pad + src_lengths = ( + (src_tokens.ne(self.eos) & src_tokens.ne(self.pad)).long().sum(dim=1) + ) + elif "source" in net_input: + src_tokens = net_input["source"] + src_lengths = ( + net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1) + if net_input["padding_mask"] is not None + else torch.tensor(src_tokens.size(-1)).to(src_tokens) + ) + elif "features" in net_input: + src_tokens = net_input["features"] + src_lengths = ( + net_input["padding_mask"].size(-1) - net_input["padding_mask"].sum(-1) + if net_input["padding_mask"] is not None + else torch.tensor(src_tokens.size(-1)).to(src_tokens) + ) + else: + raise Exception("expected src_tokens or source in net input. input keys: " + str(net_input.keys())) + + # bsz: total number of sentences in beam + # Note that src_tokens may have more than 2 dimensions (i.e. audio features) + if src_tokens['audio'] is not None: + bsz, src_len = src_tokens['audio'].size()[:2] + src_device = src_tokens['audio'].device + else: + bsz, src_len = net_input['padding_mask'].size() + src_device = src_tokens['video'].device + beam_size = self.beam_size + if constraints is not None and not self.search.supports_constraints: + raise NotImplementedError( + "Target-side constraints were provided, but search method doesn't support them" + ) + + # Initialize constraints, when active + self.search.init_constraints(constraints, beam_size) + + max_len: int = -1 + if self.match_source_len: + max_len = src_lengths.max().item() + else: + max_len = min( + int(self.max_len_a * src_len + self.max_len_b), + self.max_len - 1, + ) + assert ( + self.min_len <= max_len + ), "min_len cannot be larger than max_len, please adjust these!" + # compute the encoder output for each beam + encoder_outs = self.model.forward_encoder(net_input) + + # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores + new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1) + new_order = new_order.to(src_device).long() + encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order) + # ensure encoder_outs is a List. 
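+        # The reorder above tiles each sentence beam_size times: e.g. with bsz=2
+        # and beam_size=3, new_order is [0, 0, 0, 1, 1, 1], so every encoder output
+        # is repeated once per beam hypothesis (illustrative numbers).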
+ assert encoder_outs is not None + + # initialize buffers + scores = ( + torch.zeros(bsz * beam_size, max_len + 1).to(src_device).float() + ) # +1 for eos; pad is never chosen for scoring + tokens = ( + torch.zeros(bsz * beam_size, max_len + 2) + .to(src_device) + .long() + .fill_(self.pad) + ) # +2 for eos and pad + tokens[:, 0] = self.eos if bos_token is None else bos_token + attn: Optional[Tensor] = None + + # A list that indicates candidates that should be ignored. + # For example, suppose we're sampling and have already finalized 2/5 + # samples. Then cands_to_ignore would mark 2 positions as being ignored, + # so that we only finalize the remaining 3 samples. + cands_to_ignore = ( + torch.zeros(bsz, beam_size).to(src_device).eq(-1) + ) # forward and backward-compatible False mask + + # list of completed sentences + finalized = torch.jit.annotate( + List[List[Dict[str, Tensor]]], + [torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)], + ) # contains lists of dictionaries of infomation about the hypothesis being finalized at each step + + # a boolean array indicating if the sentence at the index is finished or not + finished = [False for i in range(bsz)] + num_remaining_sent = bsz # number of sentences remaining + + # number of candidate hypos per step + cand_size = 2 * beam_size # 2 x beam size in case half are EOS + + # offset arrays for converting between different indexing schemes + bbsz_offsets = ( + (torch.arange(0, bsz) * beam_size) + .unsqueeze(1) + .type_as(tokens) + .to(src_device) + ) + cand_offsets = torch.arange(0, cand_size).type_as(tokens).to(src_device) + + reorder_state: Optional[Tensor] = None + batch_idxs: Optional[Tensor] = None + + original_batch_idxs: Optional[Tensor] = None + if "id" in sample and isinstance(sample["id"], Tensor): + original_batch_idxs = sample["id"] + else: + original_batch_idxs = torch.arange(0, bsz).type_as(tokens) + + for step in range(max_len + 1): # one extra step for EOS marker + # reorder decoder internal states based on the prev choice of beams + if reorder_state is not None: + if batch_idxs is not None: + # update beam indices to take into account removed sentences + corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as( + batch_idxs + ) + reorder_state.view(-1, beam_size).add_( + corr.unsqueeze(-1) * beam_size + ) + original_batch_idxs = original_batch_idxs[batch_idxs] + self.model.reorder_incremental_state(incremental_states, reorder_state) + encoder_outs = self.model.reorder_encoder_out( + encoder_outs, reorder_state + ) + + lprobs, avg_attn_scores = self.model.forward_decoder( + tokens[:, : step + 1], + encoder_outs, + incremental_states, + self.temperature, + ) + + if self.lm_model is not None: + lm_out = self.lm_model(tokens[:, : step + 1]) + probs = self.lm_model.get_normalized_probs( + lm_out, log_probs=True, sample=None + ) + probs = probs[:, -1, :] * self.lm_weight + lprobs += probs + + lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs) + + lprobs[:, self.pad] = -math.inf # never select pad + lprobs[:, self.unk] -= self.unk_penalty # apply unk penalty + + # handle max length constraint + if step >= max_len: + lprobs[:, : self.eos] = -math.inf + lprobs[:, self.eos + 1 :] = -math.inf + + # handle prefix tokens (possibly with different lengths) + if ( + prefix_tokens is not None + and step < prefix_tokens.size(1) + and step < max_len + ): + lprobs, tokens, scores = self._prefix_tokens( + step, lprobs, scores, tokens, prefix_tokens, beam_size + ) + elif step < self.min_len: + # minimum length 
constraint (does not apply if using prefix_tokens) + lprobs[:, self.eos] = -math.inf + + # Record attention scores, only support avg_attn_scores is a Tensor + if avg_attn_scores is not None: + if attn is None: + attn = torch.empty( + bsz * beam_size, avg_attn_scores.size(1), max_len + 2 + ).to(scores) + attn[:, :, step + 1].copy_(avg_attn_scores) + + scores = scores.type_as(lprobs) + eos_bbsz_idx = torch.empty(0).to( + tokens + ) # indices of hypothesis ending with eos (finished sentences) + eos_scores = torch.empty(0).to( + scores + ) # scores of hypothesis ending with eos (finished sentences) + + if self.should_set_src_lengths: + self.search.set_src_lengths(src_lengths) + + if self.repeat_ngram_blocker is not None: + lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step) + + # Shape: (batch, cand_size) + cand_scores, cand_indices, cand_beams = self.search.step( + step, + lprobs.view(bsz, -1, self.vocab_size), + scores.view(bsz, beam_size, -1)[:, :, :step], + tokens[:, : step + 1], + original_batch_idxs, + ) + + # cand_bbsz_idx contains beam indices for the top candidate + # hypotheses, with a range of values: [0, bsz*beam_size), + # and dimensions: [bsz, cand_size] + cand_bbsz_idx = cand_beams.add(bbsz_offsets) + + # finalize hypotheses that end in eos + # Shape of eos_mask: (batch size, beam size) + eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf) + eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask) + + # only consider eos when it's among the top beam_size indices + # Now we know what beam item(s) to finish + # Shape: 1d list of absolute-numbered + eos_bbsz_idx = torch.masked_select( + cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size] + ) + + finalized_sents: List[int] = [] + if eos_bbsz_idx.numel() > 0: + eos_scores = torch.masked_select( + cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size] + ) + + finalized_sents = self.finalize_hypos( + step, + eos_bbsz_idx, + eos_scores, + tokens, + scores, + finalized, + finished, + beam_size, + attn, + src_lengths, + max_len, + ) + num_remaining_sent -= len(finalized_sents) + + assert num_remaining_sent >= 0 + if num_remaining_sent == 0: + break + if self.search.stop_on_max_len and step >= max_len: + break + assert step < max_len, f"{step} < {max_len}" + + # Remove finalized sentences (ones for which {beam_size} + # finished hypotheses have been generated) from the batch. 
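+            # Illustrative example of the reduction below: with bsz=4 and
+            # finalized_sents=[1, 3], batch_mask is [True, False, True, False] and
+            # batch_idxs becomes [0, 2], so only sentences 0 and 2 keep decoding.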
+ if len(finalized_sents) > 0: + new_bsz = bsz - len(finalized_sents) + + # construct batch_idxs which holds indices of batches to keep for the next pass + batch_mask = torch.ones( + bsz, dtype=torch.bool, device=cand_indices.device + ) + batch_mask[finalized_sents] = False + # TODO replace `nonzero(as_tuple=False)` after TorchScript supports it + batch_idxs = torch.arange( + bsz, device=cand_indices.device + ).masked_select(batch_mask) + + # Choose the subset of the hypothesized constraints that will continue + self.search.prune_sentences(batch_idxs) + + eos_mask = eos_mask[batch_idxs] + cand_beams = cand_beams[batch_idxs] + bbsz_offsets.resize_(new_bsz, 1) + cand_bbsz_idx = cand_beams.add(bbsz_offsets) + cand_scores = cand_scores[batch_idxs] + cand_indices = cand_indices[batch_idxs] + + if prefix_tokens is not None: + prefix_tokens = prefix_tokens[batch_idxs] + src_lengths = src_lengths[batch_idxs] + cands_to_ignore = cands_to_ignore[batch_idxs] + + scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1) + tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1) + if attn is not None: + attn = attn.view(bsz, -1)[batch_idxs].view( + new_bsz * beam_size, attn.size(1), -1 + ) + bsz = new_bsz + else: + batch_idxs = None + + # Set active_mask so that values > cand_size indicate eos hypos + # and values < cand_size indicate candidate active hypos. + # After, the min values per row are the top candidate active hypos + + # Rewrite the operator since the element wise or is not supported in torchscript. + + eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size])) + active_mask = torch.add( + eos_mask.type_as(cand_offsets) * cand_size, + cand_offsets[: eos_mask.size(1)], + ) + + # get the top beam_size active hypotheses, which are just + # the hypos with the smallest values in active_mask. + # {active_hypos} indicates which {beam_size} hypotheses + # from the list of {2 * beam_size} candidates were + # selected. Shapes: (batch size, beam size) + new_cands_to_ignore, active_hypos = torch.topk( + active_mask, k=beam_size, dim=1, largest=False + ) + + # update cands_to_ignore to ignore any finalized hypos. + cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size] + # Make sure there is at least one active item for each sentence in the batch. + assert (~cands_to_ignore).any(dim=1).all() + + # update cands_to_ignore to ignore any finalized hypos + + # {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam + # can be selected more than once). 
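+            # Sketch of the gather below: cand_bbsz_idx is (bsz, 2 * beam_size) and
+            # active_hypos holds, per sentence, the beam_size surviving columns, so
+            # torch.gather(..., dim=1, index=active_hypos) returns (bsz, beam_size)
+            # flat beam indices in [0, bsz * beam_size).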
+ active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos) + active_scores = torch.gather(cand_scores, dim=1, index=active_hypos) + + active_bbsz_idx = active_bbsz_idx.view(-1) + active_scores = active_scores.view(-1) + + # copy tokens and scores for active hypotheses + + # Set the tokens for each beam (can select the same row more than once) + tokens[:, : step + 1] = torch.index_select( + tokens[:, : step + 1], dim=0, index=active_bbsz_idx + ) + # Select the next token for each of them + tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather( + cand_indices, dim=1, index=active_hypos + ) + if step > 0: + scores[:, :step] = torch.index_select( + scores[:, :step], dim=0, index=active_bbsz_idx + ) + scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather( + cand_scores, dim=1, index=active_hypos + ) + + # Update constraints based on which candidates were selected for the next beam + self.search.update_constraints(active_hypos) + + # copy attention for active hypotheses + if attn is not None: + attn[:, :, : step + 2] = torch.index_select( + attn[:, :, : step + 2], dim=0, index=active_bbsz_idx + ) + + # reorder incremental state in decoder + reorder_state = active_bbsz_idx + + # sort by score descending + for sent in range(len(finalized)): + scores = torch.tensor( + [float(elem["score"].item()) for elem in finalized[sent]] + ) + _, sorted_scores_indices = torch.sort(scores, descending=True) + finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices] + finalized[sent] = torch.jit.annotate( + List[Dict[str, Tensor]], finalized[sent] + ) + return finalized + + def _prefix_tokens( + self, step: int, lprobs, scores, tokens, prefix_tokens, beam_size: int + ): + """Handle prefix tokens""" + prefix_toks = prefix_tokens[:, step].unsqueeze(-1).repeat(1, beam_size).view(-1) + prefix_lprobs = lprobs.gather(-1, prefix_toks.unsqueeze(-1)) + prefix_mask = prefix_toks.ne(self.pad) + lprobs[prefix_mask] = torch.tensor(-math.inf).to(lprobs) + lprobs[prefix_mask] = lprobs[prefix_mask].scatter( + -1, prefix_toks[prefix_mask].unsqueeze(-1), prefix_lprobs[prefix_mask] + ) + # if prefix includes eos, then we should make sure tokens and + # scores are the same across all beams + eos_mask = prefix_toks.eq(self.eos) + if eos_mask.any(): + # validate that the first beam matches the prefix + first_beam = tokens[eos_mask].view(-1, beam_size, tokens.size(-1))[ + :, 0, 1 : step + 1 + ] + eos_mask_batch_dim = eos_mask.view(-1, beam_size)[:, 0] + target_prefix = prefix_tokens[eos_mask_batch_dim][:, :step] + assert (first_beam == target_prefix).all() + + # copy tokens, scores and lprobs from the first beam to all beams + tokens = self.replicate_first_beam(tokens, eos_mask_batch_dim, beam_size) + scores = self.replicate_first_beam(scores, eos_mask_batch_dim, beam_size) + lprobs = self.replicate_first_beam(lprobs, eos_mask_batch_dim, beam_size) + return lprobs, tokens, scores + + def replicate_first_beam(self, tensor, mask, beam_size: int): + tensor = tensor.view(-1, beam_size, tensor.size(-1)) + tensor[mask] = tensor[mask][:, :1, :] + return tensor.view(-1, tensor.size(-1)) + + def finalize_hypos( + self, + step: int, + bbsz_idx, + eos_scores, + tokens, + scores, + finalized: List[List[Dict[str, Tensor]]], + finished: List[bool], + beam_size: int, + attn: Optional[Tensor], + src_lengths, + max_len: int, + ): + """Finalize hypothesis, store finalized information in `finalized`, and change `finished` accordingly. 
+ A sentence is finalized when {beam_size} finished items have been collected for it. + + Returns number of sentences (not beam items) being finalized. + These will be removed from the batch and not processed further. + Args: + bbsz_idx (Tensor): + """ + assert bbsz_idx.numel() == eos_scores.numel() + + # clone relevant token and attention tensors. + # tokens is (batch * beam, max_len). So the index_select + # gets the newly EOS rows, then selects cols 1..{step + 2} + tokens_clone = tokens.index_select(0, bbsz_idx)[ + :, 1 : step + 2 + ] # skip the first index, which is EOS + + tokens_clone[:, step] = self.eos + attn_clone = ( + attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2] + if attn is not None + else None + ) + + # compute scores per token position + pos_scores = scores.index_select(0, bbsz_idx)[:, : step + 1] + pos_scores[:, step] = eos_scores + # convert from cumulative to per-position scores + pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1] + + # normalize sentence-level scores + if self.normalize_scores: + eos_scores /= (step + 1) ** self.len_penalty + + # cum_unfin records which sentences in the batch are finished. + # It helps match indexing between (a) the original sentences + # in the batch and (b) the current, possibly-reduced set of + # sentences. + cum_unfin: List[int] = [] + prev = 0 + for f in finished: + if f: + prev += 1 + else: + cum_unfin.append(prev) + + # The keys here are of the form "{sent}_{unfin_idx}", where + # "unfin_idx" is the index in the current (possibly reduced) + # list of sentences, and "sent" is the index in the original, + # unreduced batch + # set() is not supported in script export + sents_seen: Dict[str, Optional[Tensor]] = {} + + # For every finished beam item + for i in range(bbsz_idx.size()[0]): + idx = bbsz_idx[i] + score = eos_scores[i] + # sentence index in the current (possibly reduced) batch + unfin_idx = idx // beam_size + # sentence index in the original (unreduced) batch + sent = unfin_idx + cum_unfin[unfin_idx] + # Cannot create dict for key type '(int, int)' in torchscript. + # The workaround is to cast int to string + seen = str(sent.item()) + "_" + str(unfin_idx.item()) + if seen not in sents_seen: + sents_seen[seen] = None + + if self.match_source_len and step > src_lengths[unfin_idx]: + score = torch.tensor(-math.inf).to(score) + + # An input sentence (among those in a batch) is finished when + # beam_size hypotheses have been collected for it + if len(finalized[sent]) < beam_size: + if attn_clone is not None: + # remove padding tokens from attn scores + hypo_attn = attn_clone[i] + else: + hypo_attn = torch.empty(0) + + finalized[sent].append( + { + "tokens": tokens_clone[i], + "score": score, + "attention": hypo_attn, # src_len x tgt_len + "alignment": torch.empty(0), + "positional_scores": pos_scores[i], + } + ) + + newly_finished: List[int] = [] + + for seen in sents_seen.keys(): + # check termination conditions for this sentence + sent: int = int(float(seen.split("_")[0])) + unfin_idx: int = int(float(seen.split("_")[1])) + + if not finished[sent] and self.is_finished( + step, unfin_idx, max_len, len(finalized[sent]), beam_size + ): + finished[sent] = True + newly_finished.append(unfin_idx) + + return newly_finished + + def is_finished( + self, + step: int, + unfin_idx: int, + max_len: int, + finalized_sent_len: int, + beam_size: int, + ): + """ + Check whether decoding for a sentence is finished, which + occurs when the list of finalized sentences has reached the + beam size, or when we reach the maximum length. 
+ """ + assert finalized_sent_len <= beam_size + if finalized_sent_len == beam_size or step == max_len: + return True + return False + + +class EnsembleModel(nn.Module): + """A wrapper around an ensemble of models.""" + + def __init__(self, models): + super().__init__() + self.models_size = len(models) + # method '__len__' is not supported in ModuleList for torch script + self.single_model = models[0] + self.models = nn.ModuleList(models) + + self.has_incremental: bool = False + if all( + hasattr(m, "decoder") and isinstance(m.decoder, FairseqIncrementalDecoder) + for m in models + ): + self.has_incremental = True + + def forward(self): + pass + + def has_encoder(self): + return hasattr(self.single_model, "encoder") + + def has_incremental_states(self): + return self.has_incremental + + def max_decoder_positions(self): + return min([m.max_decoder_positions() for m in self.models if hasattr(m, "max_decoder_positions")] + [sys.maxsize]) + + @torch.jit.export + def forward_encoder(self, net_input: Dict[str, Tensor]): + if not self.has_encoder(): + return None + return [model.encoder.forward_torchscript(net_input) for model in self.models] + + @torch.jit.export + def forward_decoder( + self, + tokens, + encoder_outs: List[Dict[str, List[Tensor]]], + incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]], + temperature: float = 1.0, + ): + log_probs = [] + avg_attn: Optional[Tensor] = None + encoder_out: Optional[Dict[str, List[Tensor]]] = None + for i, model in enumerate(self.models): + if self.has_encoder(): + encoder_out = encoder_outs[i] + # decode each model + if self.has_incremental_states(): + decoder_out = model.decoder.forward( + tokens, + encoder_out=encoder_out, + incremental_state=incremental_states[i], + ) + else: + if hasattr(model, "decoder"): + decoder_out = model.decoder.forward(tokens, encoder_out=encoder_out) + else: + decoder_out = model.forward(tokens) + + attn: Optional[Tensor] = None + decoder_len = len(decoder_out) + if decoder_len > 1 and decoder_out[1] is not None: + if isinstance(decoder_out[1], Tensor): + attn = decoder_out[1] + else: + attn_holder = decoder_out[1]["attn"] + if isinstance(attn_holder, Tensor): + attn = attn_holder + elif attn_holder is not None: + attn = attn_holder[0] + if attn is not None: + attn = attn[:, -1, :] + + decoder_out_tuple = ( + decoder_out[0][:, -1:, :].div_(temperature), + None if decoder_len <= 1 else decoder_out[1], + ) + probs = model.get_normalized_probs( + decoder_out_tuple, log_probs=True, sample=None + ) + probs = probs[:, -1, :] + if self.models_size == 1: + return probs, attn + + log_probs.append(probs) + if attn is not None: + if avg_attn is None: + avg_attn = attn + else: + avg_attn.add_(attn) + + avg_probs = torch.logsumexp(torch.stack(log_probs, dim=0), dim=0) - math.log( + self.models_size + ) + + if avg_attn is not None: + avg_attn.div_(self.models_size) + return avg_probs, avg_attn + + @torch.jit.export + def reorder_encoder_out( + self, encoder_outs: Optional[List[Dict[str, List[Tensor]]]], new_order + ): + """ + Reorder encoder output according to *new_order*. 
+ + Args: + encoder_out: output from the ``forward()`` method + new_order (LongTensor): desired order + + Returns: + *encoder_out* rearranged according to *new_order* + """ + new_outs: List[Dict[str, List[Tensor]]] = [] + if not self.has_encoder(): + return new_outs + for i, model in enumerate(self.models): + assert encoder_outs is not None + new_outs.append( + model.encoder.reorder_encoder_out(encoder_outs[i], new_order) + ) + return new_outs + + @torch.jit.export + def reorder_incremental_state( + self, + incremental_states: List[Dict[str, Dict[str, Optional[Tensor]]]], + new_order, + ): + if not self.has_incremental_states(): + return + for i, model in enumerate(self.models): + model.decoder.reorder_incremental_state_scripting( + incremental_states[i], new_order + ) + + +class SequenceGeneratorWithAlignment(SequenceGenerator): + def __init__( + self, models, tgt_dict, left_pad_target=False, print_alignment="hard", **kwargs + ): + """Generates translations of a given source sentence. + + Produces alignments following "Jointly Learning to Align and + Translate with Transformer Models" (Garg et al., EMNLP 2019). + + Args: + left_pad_target (bool, optional): Whether or not the + hypothesis should be left padded or not when they are + teacher forced for generating alignments. + """ + super().__init__(EnsembleModelWithAlignment(models), tgt_dict, **kwargs) + self.left_pad_target = left_pad_target + + if print_alignment == "hard": + self.extract_alignment = utils.extract_hard_alignment + elif print_alignment == "soft": + self.extract_alignment = utils.extract_soft_alignment + + @torch.no_grad() + def generate(self, models, sample, **kwargs): + finalized = super()._generate(sample, **kwargs) + + src_tokens = sample["net_input"]["src_tokens"] + bsz = src_tokens.shape[0] + beam_size = self.beam_size + ( + src_tokens, + src_lengths, + prev_output_tokens, + tgt_tokens, + ) = self._prepare_batch_for_alignment(sample, finalized) + if any(getattr(m, "full_context_alignment", False) for m in self.model.models): + attn = self.model.forward_align(src_tokens, src_lengths, prev_output_tokens) + else: + attn = [ + finalized[i // beam_size][i % beam_size]["attention"].transpose(1, 0) + for i in range(bsz * beam_size) + ] + + if src_tokens.device != "cpu": + src_tokens = src_tokens.to("cpu") + tgt_tokens = tgt_tokens.to("cpu") + attn = [i.to("cpu") for i in attn] + + # Process the attn matrix to extract hard alignments. 
+ for i in range(bsz * beam_size): + alignment = self.extract_alignment( + attn[i], src_tokens[i], tgt_tokens[i], self.pad, self.eos + ) + finalized[i // beam_size][i % beam_size]["alignment"] = alignment + return finalized + + def _prepare_batch_for_alignment(self, sample, hypothesis): + src_tokens = sample["net_input"]["src_tokens"] + bsz = src_tokens.shape[0] + src_tokens = ( + src_tokens[:, None, :] + .expand(-1, self.beam_size, -1) + .contiguous() + .view(bsz * self.beam_size, -1) + ) + src_lengths = sample["net_input"]["src_lengths"] + src_lengths = ( + src_lengths[:, None] + .expand(-1, self.beam_size) + .contiguous() + .view(bsz * self.beam_size) + ) + prev_output_tokens = data_utils.collate_tokens( + [beam["tokens"] for example in hypothesis for beam in example], + self.pad, + self.eos, + self.left_pad_target, + move_eos_to_beginning=True, + ) + tgt_tokens = data_utils.collate_tokens( + [beam["tokens"] for example in hypothesis for beam in example], + self.pad, + self.eos, + self.left_pad_target, + move_eos_to_beginning=False, + ) + return src_tokens, src_lengths, prev_output_tokens, tgt_tokens + + +class EnsembleModelWithAlignment(EnsembleModel): + """A wrapper around an ensemble of models.""" + + def __init__(self, models): + super().__init__(models) + + def forward_align(self, src_tokens, src_lengths, prev_output_tokens): + avg_attn = None + for model in self.models: + decoder_out = model(src_tokens, src_lengths, prev_output_tokens) + attn = decoder_out[1]["attn"][0] + if avg_attn is None: + avg_attn = attn + else: + avg_attn.add_(attn) + if len(self.models) > 1: + avg_attn.div_(len(self.models)) + return avg_attn diff --git a/src/slam_llm/models/avhubert/utils.py b/src/slam_llm/models/avhubert/utils.py new file mode 100644 index 00000000..60d57fa0 --- /dev/null +++ b/src/slam_llm/models/avhubert/utils.py @@ -0,0 +1,298 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import cv2 +import torch +import random +import numpy as np +from typing import Dict, List, Optional, Tuple + +def load_video(path): + for i in range(3): + try: + cap = cv2.VideoCapture(path) + frames = [] + while True: + ret, frame = cap.read() + if ret: + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + frames.append(frame) + else: + break + frames = np.stack(frames) + return frames + except Exception: + print(f"failed loading {path} ({i} / 3)") + if i == 2: + raise ValueError(f"Unable to load {path}") + + +class Compose(object): + """Compose several preprocess together. + Args: + preprocess (list of ``Preprocess`` objects): list of preprocess to compose. + """ + + def __init__(self, preprocess): + self.preprocess = preprocess + + def __call__(self, sample): + for t in self.preprocess: + sample = t(sample) + return sample + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + for t in self.preprocess: + format_string += '\n' + format_string += ' {0}'.format(t) + format_string += '\n)' + return format_string + + +class Normalize(object): + """Normalize a ndarray image with mean and standard deviation. + """ + + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, frames): + """ + Args: + tensor (Tensor): Tensor image of size (C, H, W) to be normalized. + Returns: + Tensor: Normalized Tensor image. 
+ """ + frames = (frames - self.mean) / self.std + return frames + + def __repr__(self): + return self.__class__.__name__+'(mean={0}, std={1})'.format(self.mean, self.std) + +class CenterCrop(object): + """Crop the given image at the center + """ + def __init__(self, size): + self.size = size + + def __call__(self, frames): + """ + Args: + img (numpy.ndarray): Images to be cropped. + Returns: + numpy.ndarray: Cropped image. + """ + t, h, w = frames.shape + th, tw = self.size + delta_w = int(round((w - tw))/2.) + delta_h = int(round((h - th))/2.) + frames = frames[:, delta_h:delta_h+th, delta_w:delta_w+tw] + return frames + + +class RandomCrop(object): + """Crop the given image at the center + """ + + def __init__(self, size): + self.size = size + + def __call__(self, frames): + """ + Args: + img (numpy.ndarray): Images to be cropped. + Returns: + numpy.ndarray: Cropped image. + """ + t, h, w = frames.shape + th, tw = self.size + delta_w = random.randint(0, w-tw) + delta_h = random.randint(0, h-th) + frames = frames[:, delta_h:delta_h+th, delta_w:delta_w+tw] + return frames + + def __repr__(self): + return self.__class__.__name__ + '(size={0})'.format(self.size) + +class HorizontalFlip(object): + """Flip image horizontally. + """ + + def __init__(self, flip_ratio): + self.flip_ratio = flip_ratio + + def __call__(self, frames): + """ + Args: + img (numpy.ndarray): Images to be flipped with a probability flip_ratio + Returns: + numpy.ndarray: Cropped image. + """ + t, h, w = frames.shape + if random.random() < self.flip_ratio: + for index in range(t): + frames[index] = cv2.flip(frames[index], 1) + return frames + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. 
mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = np.random.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = np.random.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = np.random.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - keep_length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = np.random.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + batch_indexes, starts, ends = [], [], [] + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len: + mask_idc = np.random.choice(mask_idc, min_len, replace=False) + mask[i, mask_idc] = True + vals, run_starts, run_lengths = find_runs(mask[i]) + start_indices, lengths = run_starts[vals == True], run_lengths[vals == True] + starts.append(start_indices) + ends.append(start_indices+lengths) + batch_indexes.append(np.zeros([len(start_indices)])+i) + return mask, np.concatenate(starts).astype(np.int64), np.concatenate(ends).astype(np.int64), np.concatenate(batch_indexes).astype(np.int64) + +def find_runs(x): + """Find runs of consecutive items in an array.""" + + # ensure array + x = np.asanyarray(x) + if x.ndim != 
1: + raise ValueError('only 1D array supported') + n = x.shape[0] + + # handle empty array + if n == 0: + return np.array([]), np.array([]), np.array([]) + + else: + # find run starts + loc_run_start = np.empty(n, dtype=bool) + loc_run_start[0] = True + np.not_equal(x[:-1], x[1:], out=loc_run_start[1:]) + run_starts = np.nonzero(loc_run_start)[0] + + # find run values + run_values = x[loc_run_start] + + # find run lengths + run_lengths = np.diff(np.append(run_starts, n)) + + return run_values, run_starts, run_lengths diff --git a/src/slam_llm/models/encoder.py b/src/slam_llm/models/encoder.py index 35f361f5..0b2bbff3 100644 --- a/src/slam_llm/models/encoder.py +++ b/src/slam_llm/models/encoder.py @@ -87,16 +87,32 @@ def load(cls, model_config): def extract_features(self, source, padding_mask): return self.model.extract_features(source, padding_mask)[0] -class AVEncoder: +class AVHubertEncoder: @classmethod def load(cls, model_config): - from .AV.av_net import AVNet - avnet = AVNet(model_config) - checkpoint = torch.load(model_config.TRAIN_LRS3_MODEL_FILE) - avnet.load_state_dict(checkpoint['state_dict'],strict=False) - - return avnet + import fairseq + from .avhubert import hubert_pretraining, hubert, hubert_asr + models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path]) + model = models[0] + return model + +class HubertEncoder: + + @classmethod + def load(cls, model_config): + import fairseq + models, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([model_config.encoder_path]) + model = models[0] + if model_config.encoder_type == "pretrain": + pass + elif model_config.encoder_type == "finetune": + model.w2v_encoder.proj = None + model.w2v_encoder.apply_mask = False + else: + assert model_config.encoder_type in ["pretrain", "finetune"], "encoder_type must be one of [pretrain, finetune]" + return model + class HfTextEncoder: diff --git a/src/slam_llm/models/slam_model.py b/src/slam_llm/models/slam_model.py index 6cbab4b5..4b067108 100644 --- a/src/slam_llm/models/slam_model.py +++ b/src/slam_llm/models/slam_model.py @@ -83,9 +83,13 @@ def setup_encoder(train_config, model_config, **kwargs): if encoder_name == "wavlm": from slam_llm.models.encoder import WavLMEncoder encoder = WavLMEncoder.load(model_config) - if encoder_name == "moco_wav2vec2": - from slam_llm.models.encoder import AVEncoder - encoder = AVEncoder.load(model_config) + if encoder_name == "av_hubert": + from slam_llm.models.encoder import AVHubertEncoder + encoder = AVHubertEncoder.load(model_config) + if encoder_name == "hubert": + from slam_llm.models.encoder import HubertEncoder + encoder = HubertEncoder.load(model_config) + if "llama" in encoder_name.lower(): from slam_llm.models.encoder import HfTextEncoder encoder = HfTextEncoder.load(model_config) @@ -284,8 +288,8 @@ def forward(self, audio = kwargs.get("audio", None) audio_mask = kwargs.get("audio_mask", None) visual = kwargs.get("visual", None) - vis_len = kwargs.get("vis_len", None) - maskw2v = kwargs.get("maskw2v", False) #(FIX:MZY) False for supervised learning and inference + visual_mask = kwargs.get("visual_mask", None) + # for text encoder instruct_ids = kwargs.get("instruct_ids", None) @@ -297,7 +301,7 @@ def forward(self, en_data = kwargs.get("en", None) encoder_outs = None - if audio_mel is not None or audio is not None: + if audio_mel is not None or audio is not None or visual is not None: if self.model_config.encoder_name == "whisper": encoder_outs = 
self.encoder.extract_variable_length_features(audio_mel.permute(0, 2, 1)) # bs*seq*dim if self.model_config.encoder_name == "beats": @@ -306,8 +310,18 @@ def forward(self, encoder_outs = self.encoder.model.extract_features(audio_mel.unsqueeze(dim=1), padding_mask = None, mask=False, remove_extra_tokens = False)['x'] if self.model_config.encoder_name == "wavlm": encoder_outs = self.encoder.extract_features(audio, 1 - audio_mask) #(FIX:MZY): 1-audio_mask is needed for wavlm as the padding mask - if self.model_config.encoder_name == "moco_wav2vec2": - encoder_outs , inputLenBatch, audio_mel_post_mask = self.encoder((audio, audio_mask, visual, vis_len) ,maskw2v) # bs*seq*dim + if self.model_config.encoder_name == "hubert": + results = self.encoder(source = audio, padding_mask = 1-audio_mask) + if self.model_config.encoder_type == "pretrain": + encoder_outs, audio_mel_post_mask = results["x"], results["padding_mask"] + if self.model_config.encoder_type == "finetune": + encoder_outs, audio_mel_post_mask = results["encoder_out"], results["padding_mask"] + encoder_outs = encoder_outs.transpose(0, 1) + if self.model_config.encoder_name == "av_hubert": + results = self.encoder(source={'video':visual, 'audio':audio}, padding_mask=visual_mask) # bs*seq*dim + encoder_outs, audio_mel_post_mask = results["encoder_out"], results["padding_mask"] + encoder_outs = encoder_outs.transpose(0, 1) + audio_mel_post_mask = (~audio_mel_post_mask).float() if self.encoder is None: encoder_outs = audio_mel if audio_mel is not None else audio @@ -315,6 +329,8 @@ def forward(self, encoder_outs = self.encoder_projector(encoder_outs, audio_mel_post_mask) if self.model_config.encoder_projector == "linear": encoder_outs = self.encoder_projector(encoder_outs) + if self.model_config.encoder_projector == "cov1d-linear": + encoder_outs = self.encoder_projector(encoder_outs) if instruct_ids is not None: if self.encoder is not None: @@ -326,6 +342,7 @@ def forward(self, encoder_outs = self.encoder_projector(encoder_outs) + if input_ids is not None: input_ids[input_ids == -1] = 0 if isinstance(self.llm, T5ForConditionalGeneration): @@ -391,7 +408,7 @@ def generate(self, model_outputs = self.llm.generate( inputs_embeds=inputs_embeds, - max_length=kwargs.get("max_length", 200), + # max_length=kwargs.get("max_length", 200), max_new_tokens=kwargs.get("max_new_tokens", 200), num_beams=kwargs.get("num_beams", 4), do_sample=kwargs.get("do_sample", False), diff --git a/src/slam_llm/pipeline/finetune.py b/src/slam_llm/pipeline/finetune.py index 240aa945..36cdbe01 100644 --- a/src/slam_llm/pipeline/finetune.py +++ b/src/slam_llm/pipeline/finetune.py @@ -49,7 +49,8 @@ from omegaconf import DictConfig, ListConfig, OmegaConf from pathlib import Path -@hydra.main(config_name=None, version_base=None) +# @hydra.main(config_name=None, version_base=None) +@hydra.main(config_name=None) def main_hydra(cfg: DictConfig): def to_plain_list(cfg_item): if isinstance(cfg_item, ListConfig): @@ -83,11 +84,20 @@ def main(kwargs: DictConfig): kwargs.log_config, \ kwargs.dataset_config fsdp_config.use_fp16 = train_config.use_fp16 - del kwargs.train_config - del kwargs.fsdp_config - del kwargs.model_config - del kwargs.log_config - del kwargs.dataset_config + if model_config.encoder_name=="av_hubert": + OmegaConf.set_struct(kwargs,False) + del kwargs["train_config"] + del kwargs["fsdp_config"] + del kwargs["model_config"] + del kwargs["log_config"] + del kwargs["dataset_config"] + OmegaConf.set_struct(kwargs,True) + else: + del kwargs.train_config + del 
kwargs.fsdp_config + del kwargs.model_config + del kwargs.log_config + del kwargs.dataset_config # Set log if not os.path.exists(os.path.dirname(log_config.log_file)): diff --git a/src/slam_llm/pipeline/inference_batch.py b/src/slam_llm/pipeline/inference_batch.py index 6eb7e52b..ba499d3c 100644 --- a/src/slam_llm/pipeline/inference_batch.py +++ b/src/slam_llm/pipeline/inference_batch.py @@ -20,7 +20,7 @@ from omegaconf import DictConfig, ListConfig, OmegaConf -@hydra.main(config_name=None, version_base=None) +@hydra.main(config_name=None) def main_hydra(cfg: DictConfig): def to_plain_list(cfg_item): if isinstance(cfg_item, ListConfig): @@ -53,12 +53,20 @@ def main(kwargs: DictConfig): kwargs.model_config, \ kwargs.log_config, \ kwargs.dataset_config - - del kwargs.train_config - del kwargs.fsdp_config - del kwargs.model_config - del kwargs.log_config - del kwargs.dataset_config + if model_config.encoder_name=="av_hubert": + OmegaConf.set_struct(kwargs,False) + del kwargs["train_config"] + del kwargs["fsdp_config"] + del kwargs["model_config"] + del kwargs["log_config"] + del kwargs["dataset_config"] + OmegaConf.set_struct(kwargs,True) + else: + del kwargs.train_config + del kwargs.fsdp_config + del kwargs.model_config + del kwargs.log_config + del kwargs.dataset_config # Set log if not os.path.exists(os.path.dirname(log_config.log_file)): diff --git a/src/slam_llm/utils/custom_utils.py b/src/slam_llm/utils/custom_utils.py new file mode 100644 index 00000000..60d57fa0 --- /dev/null +++ b/src/slam_llm/utils/custom_utils.py @@ -0,0 +1,298 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import cv2 +import torch +import random +import numpy as np +from typing import Dict, List, Optional, Tuple + +def load_video(path): + for i in range(3): + try: + cap = cv2.VideoCapture(path) + frames = [] + while True: + ret, frame = cap.read() + if ret: + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + frames.append(frame) + else: + break + frames = np.stack(frames) + return frames + except Exception: + print(f"failed loading {path} ({i} / 3)") + if i == 2: + raise ValueError(f"Unable to load {path}") + + +class Compose(object): + """Compose several preprocess together. + Args: + preprocess (list of ``Preprocess`` objects): list of preprocess to compose. + """ + + def __init__(self, preprocess): + self.preprocess = preprocess + + def __call__(self, sample): + for t in self.preprocess: + sample = t(sample) + return sample + + def __repr__(self): + format_string = self.__class__.__name__ + '(' + for t in self.preprocess: + format_string += '\n' + format_string += ' {0}'.format(t) + format_string += '\n)' + return format_string + + +class Normalize(object): + """Normalize a ndarray image with mean and standard deviation. + """ + + def __init__(self, mean, std): + self.mean = mean + self.std = std + + def __call__(self, frames): + """ + Args: + tensor (Tensor): Tensor image of size (C, H, W) to be normalized. + Returns: + Tensor: Normalized Tensor image. + """ + frames = (frames - self.mean) / self.std + return frames + + def __repr__(self): + return self.__class__.__name__+'(mean={0}, std={1})'.format(self.mean, self.std) + +class CenterCrop(object): + """Crop the given image at the center + """ + def __init__(self, size): + self.size = size + + def __call__(self, frames): + """ + Args: + img (numpy.ndarray): Images to be cropped. 
+ Returns: + numpy.ndarray: Cropped image. + """ + t, h, w = frames.shape + th, tw = self.size + delta_w = int(round((w - tw))/2.) + delta_h = int(round((h - th))/2.) + frames = frames[:, delta_h:delta_h+th, delta_w:delta_w+tw] + return frames + + +class RandomCrop(object): + """Crop the given image at the center + """ + + def __init__(self, size): + self.size = size + + def __call__(self, frames): + """ + Args: + img (numpy.ndarray): Images to be cropped. + Returns: + numpy.ndarray: Cropped image. + """ + t, h, w = frames.shape + th, tw = self.size + delta_w = random.randint(0, w-tw) + delta_h = random.randint(0, h-th) + frames = frames[:, delta_h:delta_h+th, delta_w:delta_w+tw] + return frames + + def __repr__(self): + return self.__class__.__name__ + '(size={0})'.format(self.size) + +class HorizontalFlip(object): + """Flip image horizontally. + """ + + def __init__(self, flip_ratio): + self.flip_ratio = flip_ratio + + def __call__(self, frames): + """ + Args: + img (numpy.ndarray): Images to be flipped with a probability flip_ratio + Returns: + numpy.ndarray: Cropped image. + """ + t, h, w = frames.shape + if random.random() < self.flip_ratio: + for index in range(t): + frames[index] = cv2.flip(frames[index], 1) + return frames + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. 
mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + all_num_mask = int( + # add a random number for probabilistic rounding + mask_prob * all_sz / float(mask_length) + + np.random.rand() + ) + + all_num_mask = max(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = np.random.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = np.random.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = np.random.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - keep_length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = np.random.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + + min_len = min([len(m) for m in mask_idcs]) + batch_indexes, starts, ends = [], [], [] + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len: + mask_idc = np.random.choice(mask_idc, min_len, replace=False) + mask[i, mask_idc] = True + vals, run_starts, run_lengths = find_runs(mask[i]) + start_indices, lengths = run_starts[vals == True], run_lengths[vals == True] + starts.append(start_indices) + ends.append(start_indices+lengths) + batch_indexes.append(np.zeros([len(start_indices)])+i) + return mask, np.concatenate(starts).astype(np.int64), np.concatenate(ends).astype(np.int64), np.concatenate(batch_indexes).astype(np.int64) + +def find_runs(x): + """Find runs of consecutive items in an array.""" + + # ensure array + x = np.asanyarray(x) + if x.ndim != 
1: + raise ValueError('only 1D array supported') + n = x.shape[0] + + # handle empty array + if n == 0: + return np.array([]), np.array([]), np.array([]) + + else: + # find run starts + loc_run_start = np.empty(n, dtype=bool) + loc_run_start[0] = True + np.not_equal(x[:-1], x[1:], out=loc_run_start[1:]) + run_starts = np.nonzero(loc_run_start)[0] + + # find run values + run_values = x[loc_run_start] + + # find run lengths + run_lengths = np.diff(np.append(run_starts, n)) + + return run_values, run_starts, run_lengths diff --git a/src/slam_llm/utils/dataset_utils.py b/src/slam_llm/utils/dataset_utils.py index a43a603f..87dac75e 100644 --- a/src/slam_llm/utils/dataset_utils.py +++ b/src/slam_llm/utils/dataset_utils.py @@ -34,7 +34,7 @@ def get_custom_dataset(dataset_config, tokenizer, split: str): if not module_path.endswith(".py"): raise ValueError(f"Dataset file {module_path} is not a .py file.") - module_path = Path(module_path) + module_path = Path("/root/SLAM-LLM/"+module_path) #TODO if not module_path.is_file(): raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.") diff --git a/src/slam_llm/utils/model_utils.py b/src/slam_llm/utils/model_utils.py index 1b440abe..7a6b1967 100644 --- a/src/slam_llm/utils/model_utils.py +++ b/src/slam_llm/utils/model_utils.py @@ -16,8 +16,8 @@ def get_custom_model_factory(model_config, logger): if not module_path.endswith(".py"): raise ValueError(f"Dataset file {module_path} is not a .py file.") - - module_path = Path(module_path) + + module_path = Path("/root/SLAM-LLM/"+module_path) #TODO if not module_path.is_file(): raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
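For reference, the alignment post-processing in SequenceGeneratorWithAlignment reduces each hypothesis' attention matrix to one source position per target token. The snippet below is a simplified, illustrative stand-in for that idea; it is not fairseq's utils.extract_hard_alignment and skips pad/eos filtering.

import torch

def toy_hard_alignment(attn: torch.Tensor):
    # attn: (tgt_len, src_len) attention weights for a single hypothesis
    src_pos = attn.argmax(dim=1)  # strongest source position for every target token
    return [(int(s), t) for t, s in enumerate(src_pos.tolist())]

attn = torch.softmax(torch.randn(3, 5), dim=1)  # 3 target tokens, 5 source frames
print(toy_hard_alignment(attn))                 # e.g. [(4, 0), (1, 1), (2, 2)]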
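The video transforms added in avhubert/utils.py compose in the usual way. A minimal sketch, assuming 96x96 grayscale mouth crops and illustrative normalization statistics (none of these values come from this patch):

import numpy as np
from slam_llm.models.avhubert.utils import Compose, Normalize, CenterCrop

transform = Compose([
    Normalize(0.0, 255.0),    # scale raw uint8 pixel values into [0, 1]
    CenterCrop((88, 88)),     # crop every 96x96 frame down to 88x88
    Normalize(0.421, 0.165),  # standardize with assumed dataset mean/std
])
frames = np.random.randint(0, 256, size=(75, 96, 96)).astype(np.float32)  # (T, H, W)
print(transform(frames).shape)  # (75, 88, 88)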
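Likewise, compute_mask_indices and find_runs (duplicated in slam_llm/utils/custom_utils.py) can be exercised on their own. The sketch below uses arbitrary shapes and masking parameters and keeps the default settings, under which masked spans may overlap:

import torch
from slam_llm.utils.custom_utils import compute_mask_indices, find_runs

B, T = 2, 100
padding_mask = torch.zeros(B, T, dtype=torch.bool)  # no frames are padded
mask, starts, ends, batch_idx = compute_mask_indices(
    shape=(B, T),
    padding_mask=padding_mask,
    mask_prob=0.65,
    mask_length=10,
)
print(mask.shape, int(mask.sum()))          # (2, 100) and the number of masked frames
vals, run_starts, run_lengths = find_runs(mask[0])
print(run_starts[vals], run_lengths[vals])  # start and length of every masked span in row 0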
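Finally, a hypothetical sketch of how HubertEncoder.load is driven by the new encoder_type field. The config class and checkpoint path below are placeholders (the real values come from ModelConfig and the shell scripts), and fairseq must be importable:

from dataclasses import dataclass
from slam_llm.models.encoder import HubertEncoder

@dataclass
class DummyModelConfig:
    encoder_path: str = "/path/to/hubert_xtralarge_ll60k_finetune_ls960.pt"
    encoder_type: str = "finetune"  # "pretrain" keeps the checkpoint untouched

encoder = HubertEncoder.load(DummyModelConfig())
# For a finetuned checkpoint the CTC projection (w2v_encoder.proj) is dropped and
# input masking is disabled, so forward() returns frame-level features that the
# linear projector in slam_model.py can consume.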