diff --git a/scripts/conf/asr_aya_lora.yaml b/scripts/conf/asr_aya_lora.yaml
index e4d3a7f8..f7fb1e19 100644
--- a/scripts/conf/asr_aya_lora.yaml
+++ b/scripts/conf/asr_aya_lora.yaml
@@ -83,7 +83,7 @@ train_config:
 
 dataset_config:
   dataset: "speech_dataset"
-  file: "src/llama_recipes/datasets/speech_dataset.py:get_speech_dataset"
+  file: "src/slam_llm/datasets/speech_dataset.py:get_speech_dataset"
   train_data_path: null
   val_data_path: null
   train_split: "train"
diff --git a/scripts/conf/asr_vicuna_lora.yaml b/scripts/conf/asr_vicuna_lora.yaml
index 123c5ba9..5a9c277b 100644
--- a/scripts/conf/asr_vicuna_lora.yaml
+++ b/scripts/conf/asr_vicuna_lora.yaml
@@ -83,7 +83,7 @@ train_config:
 
 dataset_config:
   dataset: "speech_dataset"
-  file: "src/slam-llm/datasets/speech_dataset.py:get_speech_dataset"
+  file: "src/slam_llm/datasets/speech_dataset.py:get_speech_dataset"
   train_data_path: null
   val_data_path: null
   train_split: "train"
diff --git a/src/slam_llm/datasets/speech_dataset.py b/src/slam_llm/datasets/speech_dataset.py
index 17922ddb..a863b6ec 100644
--- a/src/slam_llm/datasets/speech_dataset.py
+++ b/src/slam_llm/datasets/speech_dataset.py
@@ -132,7 +132,7 @@ def __getitem__(self, index):
                 "target": target,
             }
 
-        answer = self.answer_template.format(target.lower())
+        answer = self.answer_template.format(target)
         example = prompt + answer  # FIX(MZY): avoid putting a bos token before answer.
         example_ids = self.tokenizer.encode(example)  # [prompt,answer]
         example_ids.append(self.tokenizer.eos_token_id)  # [prompt,answer,eos]
diff --git a/src/slam_llm/model_checkpointing/checkpoint_handler.py b/src/slam_llm/model_checkpointing/checkpoint_handler.py
index 4970bd68..b2a1ee7d 100644
--- a/src/slam_llm/model_checkpointing/checkpoint_handler.py
+++ b/src/slam_llm/model_checkpointing/checkpoint_handler.py
@@ -164,9 +164,9 @@ def save_model_checkpoint(
     logger.info(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")
 
 
-def save_model_checkpoint_peft(model, optimizer, rank, cfg, epoch=0):
+def save_model_checkpoint_peft(model, optimizer, rank, cfg, epoch=0, step=0):
     logger.info(f"--> saving model ...")
-    save_dir = os.path.join(cfg.output_dir, cfg.model_name, str(epoch+1))
+    save_dir = os.path.join(cfg.output_dir, cfg.model_name, str(epoch+1), str(step+1))
     os.makedirs(save_dir, exist_ok=True)
     if not cfg.freeze_llm:
         if hasattr(model, "module"): #(FIX:MZY): a hack to deal with the model wrapped in DDP
@@ -191,7 +191,7 @@
         torch.save(encoder_dict, save_full_path)
         logger.info(f"encoder saved at {save_full_path}")
 
-    logger.info(f"model checkpoint saved for epoch {epoch+1}\n")
+    logger.info(f"model checkpoint saved for epoch {epoch+1} step {step+1}\n")
 
 def save_model_checkpoint_peft_full_shard(model, optimizer, rank, cfg, epoch=0):
     with FSDP.state_dict_type(
diff --git a/src/slam_llm/utils/config_utils.py b/src/slam_llm/utils/config_utils.py
index fe12fe17..743b3734 100644
--- a/src/slam_llm/utils/config_utils.py
+++ b/src/slam_llm/utils/config_utils.py
@@ -16,7 +16,6 @@
 
 # from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
 from slam_llm.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler
-from slam_llm.utils.dataset_utils import DATASET_PREPROC
 
 from omegaconf import OmegaConf
 
@@ -66,18 +65,6 @@ def generate_peft_config(train_config):
     return peft_config
 
 
-# def generate_dataset_config(train_config, kwargs):
-#     names = tuple(DATASET_PREPROC.keys())
-#
-#     assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}"
-#
-#     dataset_config = {k:v for k, v in inspect.getmembers(datasets)}[train_config.dataset]()
-#
-#     update_config(dataset_config, **kwargs)
-#
-#     return dataset_config
-
-
 def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
     kwargs = {}
     batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size
diff --git a/src/slam_llm/utils/dataset_utils.py b/src/slam_llm/utils/dataset_utils.py
index 8a45c686..a43a603f 100644
--- a/src/slam_llm/utils/dataset_utils.py
+++ b/src/slam_llm/utils/dataset_utils.py
@@ -49,8 +49,6 @@ def get_custom_dataset(dataset_config, tokenizer, split: str):
 def get_preprocessed_dataset(
     tokenizer, dataset_config, split: str = "train"
 ) -> torch.utils.data.Dataset:
-    if not dataset_config.dataset in DATASET_PREPROC:
-        raise NotImplementedError(f"{dataset_config.dataset} is not (yet) implemented")
 
     def get_split():
         return (
diff --git a/src/slam_llm/utils/train_utils.py b/src/slam_llm/utils/train_utils.py
index 019930fd..6a64726e 100644
--- a/src/slam_llm/utils/train_utils.py
+++ b/src/slam_llm/utils/train_utils.py
@@ -182,19 +182,19 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 elif fsdp_config.sharding_strategy == ShardingStrategy.NO_SHARD:
                     if rank==0:
                         save_model_checkpoint_peft(
-                            model, optimizer, rank, train_config, epoch=epoch
+                            model, optimizer, rank, train_config, epoch=epoch, step=step
                         )
                     dist.barrier()
                 elif train_config.enable_ddp:
                     if rank==0:
                         save_model_checkpoint_peft(
-                            model, optimizer, rank, train_config, epoch=epoch
+                            model, optimizer, rank, train_config, epoch=epoch, step=step
                         )
                     dist.barrier()
                 else:
                     # model.save_pretrained(train_config.output_dir)
                     save_model_checkpoint_peft(
-                        model, optimizer, rank, train_config, epoch=epoch
+                        model, optimizer, rank, train_config, epoch=epoch, step=step
                     )
                 if train_config.enable_fsdp or train_config.enable_ddp:
                     if rank==0:
@@ -212,18 +212,18 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 elif fsdp_config.sharding_strategy == ShardingStrategy.NO_SHARD:
                     if rank==0:
                         save_model_checkpoint_peft(
-                            model, optimizer, rank, train_config, epoch=epoch
+                            model, optimizer, rank, train_config, epoch=epoch, step=step
                         )
                     dist.barrier()
                 elif train_config.enable_ddp:
                     if rank==0:
                         save_model_checkpoint_peft(
-                            model, optimizer, rank, train_config, epoch=epoch
+                            model, optimizer, rank, train_config, epoch=epoch, step=step
                        )
                     dist.barrier()
                 else:
                     save_model_checkpoint_peft(
-                        model, optimizer, rank, train_config, epoch=epoch
+                        model, optimizer, rank, train_config, epoch=epoch, step=step
                     )
             else: