
Commit 712cb87

fix a minor bug in dataset config
ddlBoJack committed Mar 14, 2024
1 parent efc3494 commit 712cb87
Showing 7 changed files with 12 additions and 27 deletions.
scripts/conf/asr_aya_lora.yaml (2 changes: 1 addition & 1 deletion)
@@ -83,7 +83,7 @@ train_config:

dataset_config:
dataset: "speech_dataset"
-file: "src/llama_recipes/datasets/speech_dataset.py:get_speech_dataset"
+file: "src/slam_llm/datasets/speech_dataset.py:get_speech_dataset"
train_data_path: null
val_data_path: null
train_split: "train"
scripts/conf/asr_vicuna_lora.yaml (2 changes: 1 addition & 1 deletion)
@@ -83,7 +83,7 @@ train_config:

dataset_config:
dataset: "speech_dataset"
-file: "src/slam-llm/datasets/speech_dataset.py:get_speech_dataset"
+file: "src/slam_llm/datasets/speech_dataset.py:get_speech_dataset"
train_data_path: null
val_data_path: null
train_split: "train"
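
Both configs point the dataset loader at a "<path>.py:<function>" spec; the old values ("src/llama_recipes/...", "src/slam-llm/...") do not match the src/slam_llm layout used everywhere else in this repository, so a file-based loader would not find the module. As a rough illustration of how such a spec can be resolved (a minimal sketch, not the project's actual loader; load_dataset_factory is a hypothetical helper name):

import importlib.util
from pathlib import Path

def load_dataset_factory(spec: str):
    # Split "src/slam_llm/datasets/speech_dataset.py:get_speech_dataset" into a
    # file path and a function name, import the module from that file, and
    # return the named function.
    module_path, func_name = spec.split(":")
    module_spec = importlib.util.spec_from_file_location(Path(module_path).stem, module_path)
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)
    return getattr(module, func_name)

# get_speech_dataset = load_dataset_factory("src/slam_llm/datasets/speech_dataset.py:get_speech_dataset")
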
src/slam_llm/datasets/speech_dataset.py (2 changes: 1 addition & 1 deletion)
@@ -132,7 +132,7 @@ def __getitem__(self, index):
"target": target,
}

-answer = self.answer_template.format(target.lower())
+answer = self.answer_template.format(target)
example = prompt + answer # FIX(MZY): avoid putting a bos token before answer.
example_ids = self.tokenizer.encode(example) # [prompt,answer]
example_ids.append(self.tokenizer.eos_token_id) # [prompt,answer,eos]
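
The only behavioural change in this hunk is that the target transcription is no longer lowercased before being formatted into the answer. A toy illustration (the template string below is made up, not the one defined in speech_dataset.py):

answer_template = "{}"   # hypothetical template, for illustration only
target = "HELLO WORLD"
old_answer = answer_template.format(target.lower())  # before this commit: "hello world"
new_answer = answer_template.format(target)          # after this commit:  "HELLO WORLD"
print(old_answer, "->", new_answer)
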
src/slam_llm/model_checkpointing/checkpoint_handler.py (6 changes: 3 additions & 3 deletions)
@@ -164,9 +164,9 @@ def save_model_checkpoint(

logger.info(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n")

-def save_model_checkpoint_peft(model, optimizer, rank, cfg, epoch=0):
+def save_model_checkpoint_peft(model, optimizer, rank, cfg, epoch=0, step=0):
logger.info(f"--> saving model ...")
-save_dir = os.path.join(cfg.output_dir, cfg.model_name, str(epoch+1))
+save_dir = os.path.join(cfg.output_dir, cfg.model_name, str(epoch+1), str(step+1))
os.makedirs(save_dir, exist_ok=True)
if not cfg.freeze_llm:
if hasattr(model, "module"): #(FIX:MZY): a hack to deal with the model wrapped in DDP
@@ -191,7 +191,7 @@ def save_model_checkpoint_peft(model, optimizer, rank, cfg, epoch=0):
torch.save(encoder_dict, save_full_path)
logger.info(f"encoder saved at {save_full_path}")

-logger.info(f"model checkpoint saved for epoch {epoch+1}\n")
+logger.info(f"model checkpoint saved for epoch {epoch+1} step {step+1}\n")

def save_model_checkpoint_peft_full_shard(model, optimizer, rank, cfg, epoch=0):
with FSDP.state_dict_type(
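
The effect of threading step into the save path follows directly from the os.path.join line above: PEFT checkpoints are now written one directory level deeper, keyed by step. A small sketch with made-up values:

import os

# Hypothetical values, only to show the resulting directory layout.
output_dir, model_name = "out/exp1", "asr_epoch_step"
epoch, step = 2, 999

old_dir = os.path.join(output_dir, model_name, str(epoch + 1))
new_dir = os.path.join(output_dir, model_name, str(epoch + 1), str(step + 1))
print(old_dir)  # out/exp1/asr_epoch_step/3
print(new_dir)  # out/exp1/asr_epoch_step/3/1000
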
src/slam_llm/utils/config_utils.py (13 changes: 0 additions & 13 deletions)
@@ -16,7 +16,6 @@

# from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
from slam_llm.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler
-from slam_llm.utils.dataset_utils import DATASET_PREPROC

from omegaconf import OmegaConf

@@ -66,18 +65,6 @@ def generate_peft_config(train_config):
return peft_config


-# def generate_dataset_config(train_config, kwargs):
-#     names = tuple(DATASET_PREPROC.keys())
-#
-#     assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}"
-#
-#     dataset_config = {k:v for k, v in inspect.getmembers(datasets)}[train_config.dataset]()
-#
-#     update_config(dataset_config, **kwargs)
-#
-#     return dataset_config


def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
kwargs = {}
batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size
src/slam_llm/utils/dataset_utils.py (2 changes: 0 additions & 2 deletions)
@@ -49,8 +49,6 @@ def get_custom_dataset(dataset_config, tokenizer, split: str):
def get_preprocessed_dataset(
tokenizer, dataset_config, split: str = "train"
) -> torch.utils.data.Dataset:
-    if not dataset_config.dataset in DATASET_PREPROC:
-        raise NotImplementedError(f"{dataset_config.dataset} is not (yet) implemented")

def get_split():
return (
src/slam_llm/utils/train_utils.py (12 changes: 6 additions & 6 deletions)
@@ -182,19 +182,19 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
elif fsdp_config.sharding_strategy == ShardingStrategy.NO_SHARD:
if rank==0:
save_model_checkpoint_peft(
-model, optimizer, rank, train_config, epoch=epoch
+model, optimizer, rank, train_config, epoch=epoch, step=step
)
dist.barrier()
elif train_config.enable_ddp:
if rank==0:
save_model_checkpoint_peft(
-model, optimizer, rank, train_config, epoch=epoch
+model, optimizer, rank, train_config, epoch=epoch, step=step
)
dist.barrier()
else:
# model.save_pretrained(train_config.output_dir)
save_model_checkpoint_peft(
-model, optimizer, rank, train_config, epoch=epoch
+model, optimizer, rank, train_config, epoch=epoch, step=step
)
if train_config.enable_fsdp or train_config.enable_ddp:
if rank==0:
@@ -212,18 +212,18 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
elif fsdp_config.sharding_strategy == ShardingStrategy.NO_SHARD:
if rank==0:
save_model_checkpoint_peft(
-model, optimizer, rank, train_config, epoch=epoch
+model, optimizer, rank, train_config, epoch=epoch, step=step
)
dist.barrier()
elif train_config.enable_ddp:
if rank==0:
save_model_checkpoint_peft(
-model, optimizer, rank, train_config, epoch=epoch
+model, optimizer, rank, train_config, epoch=epoch, step=step
)
dist.barrier()
else:
save_model_checkpoint_peft(
-model, optimizer, rank, train_config, epoch=epoch
+model, optimizer, rank, train_config, epoch=epoch, step=step
)

else:
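
Every checkpointing branch in train() now forwards the current step, which matters because mid-epoch saves would previously resolve to the same epoch directory and overwrite one another. A toy demonstration of that difference (paths are made up, not SLAM-LLM output):

import os

epoch = 0
saves_at_steps = [1000, 2000]
old_dirs = {os.path.join("out", "model", str(epoch + 1)) for _ in saves_at_steps}
new_dirs = {os.path.join("out", "model", str(epoch + 1), str(step + 1)) for step in saves_at_steps}
print(len(old_dirs), len(new_dirs))  # prints "1 2": one shared directory before, two distinct ones now
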
